[ir] Emit short-circuit as an `If` node
This Cl removes the `&&` and `||` logical binary nodes and replaces them
with a var declaration and if node.
Bug: tint:1925
Change-Id: I9f25411a9b9c909fa25f2f37cbd51181ac584acc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/130500
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc
index 19caf6f..fd6f500 100644
--- a/src/tint/ir/builder_impl.cc
+++ b/src/tint/ir/builder_impl.cc
@@ -778,7 +778,66 @@
return inst;
}
+// A short-circut needs special treatment. The short-circuit is decomposed into the relevant if
+// statements and declarations.
+utils::Result<Value*> BuilderImpl::EmitShortCircuit(const ast::BinaryExpression* expr) {
+ switch (expr->op) {
+ case ast::BinaryOp::kLogicalAnd:
+ case ast::BinaryOp::kLogicalOr:
+ break;
+ default:
+ TINT_ICE(IR, diagnostics_) << "invalid operation type for short-circut decomposition";
+ return utils::Failure;
+ }
+
+ auto lhs = EmitExpression(expr->lhs);
+ if (!lhs) {
+ return utils::Failure;
+ }
+
+ auto* ty = builder.ir.types.Get<type::Bool>();
+ auto* result_var =
+ builder.Declare(ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+ current_flow_block->instructions.Push(result_var);
+
+ auto* lhs_store = builder.Store(result_var, lhs.Get());
+ current_flow_block->instructions.Push(lhs_store);
+
+ auto* if_node = builder.CreateIf();
+ if_node->condition = lhs.Get();
+ BranchTo(if_node);
+
+ utils::Result<Value*> rhs;
+ {
+ FlowStackScope scope(this, if_node);
+
+ // If this is an `&&` then we only evaluate the RHS expression in the true block.
+ // If this is an `||` then we only evaluate the RHS expression in the false block.
+ if (expr->op == ast::BinaryOp::kLogicalAnd) {
+ current_flow_block = if_node->true_.target->As<Block>();
+ } else {
+ current_flow_block = if_node->false_.target->As<Block>();
+ }
+
+ rhs = EmitExpression(expr->rhs);
+ if (!rhs) {
+ return utils::Failure;
+ }
+ auto* rhs_store = builder.Store(result_var, rhs.Get());
+ current_flow_block->instructions.Push(rhs_store);
+
+ BranchTo(if_node->merge.target);
+ }
+ current_flow_block = if_node->merge.target->As<Block>();
+
+ return result_var;
+}
+
utils::Result<Value*> BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) {
+ if (expr->op == ast::BinaryOp::kLogicalAnd || expr->op == ast::BinaryOp::kLogicalOr) {
+ return EmitShortCircuit(expr);
+ }
+
auto lhs = EmitExpression(expr->lhs);
if (!lhs) {
return utils::Failure;
@@ -803,12 +862,6 @@
case ast::BinaryOp::kXor:
inst = builder.Xor(ty, lhs.Get(), rhs.Get());
break;
- case ast::BinaryOp::kLogicalAnd:
- inst = builder.LogicalAnd(ty, lhs.Get(), rhs.Get());
- break;
- case ast::BinaryOp::kLogicalOr:
- inst = builder.LogicalOr(ty, lhs.Get(), rhs.Get());
- break;
case ast::BinaryOp::kEqual:
inst = builder.Equal(ty, lhs.Get(), rhs.Get());
break;
@@ -848,6 +901,10 @@
case ast::BinaryOp::kModulo:
inst = builder.Modulo(ty, lhs.Get(), rhs.Get());
break;
+ case ast::BinaryOp::kLogicalAnd:
+ case ast::BinaryOp::kLogicalOr:
+ TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
+ return utils::Failure;
case ast::BinaryOp::kNone:
TINT_ICE(IR, diagnostics_) << "missing binary operand type";
return utils::Failure;
diff --git a/src/tint/ir/builder_impl.h b/src/tint/ir/builder_impl.h
index 7b18492..58d4844 100644
--- a/src/tint/ir/builder_impl.h
+++ b/src/tint/ir/builder_impl.h
@@ -166,6 +166,11 @@
/// @returns the value storing the result if successful, utils::Failure otherwise
utils::Result<Value*> EmitUnary(const ast::UnaryOpExpression* expr);
+ /// Emits a short-circult binary expression
+ /// @param expr the binary expression
+ /// @returns the value storing the result if successful, utils::Failure otherwise
+ utils::Result<Value*> EmitShortCircuit(const ast::BinaryExpression* expr);
+
/// Emits a binary expression
/// @param expr the binary expression
/// @returns the value storing the result if successful, utils::Failure otherwise
diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc
index 1e6a26d..bcba641 100644
--- a/src/tint/ir/builder_impl_test.cc
+++ b/src/tint/ir/builder_impl_test.cc
@@ -1773,16 +1773,33 @@
auto* expr = LogicalAnd(Call("my_func"), false);
WrapInFunction(expr);
- auto& b = CreateBuilder();
- InjectFlowBlock();
- auto r = b.EmitExpression(expr);
- ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
- ASSERT_TRUE(r);
+ auto r = Build();
+ ASSERT_TRUE(r) << Error();
+ auto m = r.Move();
- Disassembler d(b.builder.ir);
- d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
- EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func
-%2(bool) = log_and %1(bool), false
+ EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+ %fn1 = block
+ ret true
+func_end
+
+%fn2 = func test_function
+ %fn3 = block
+ %1(bool) = call my_func
+ %2(bool) = var function read_write
+ store %2(bool), %1(bool)
+ branch %fn4
+
+ %fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7]
+ # true branch
+ %fn5 = block
+ store %2(bool), false
+ branch %fn7
+
+ # if merge
+ %fn7 = block
+ ret
+func_end
+
)");
}
@@ -1791,16 +1808,34 @@
auto* expr = LogicalOr(Call("my_func"), true);
WrapInFunction(expr);
- auto& b = CreateBuilder();
- InjectFlowBlock();
- auto r = b.EmitExpression(expr);
- ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
- ASSERT_TRUE(r);
+ auto r = Build();
+ ASSERT_TRUE(r) << Error();
+ auto m = r.Move();
- Disassembler d(b.builder.ir);
- d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
- EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func
-%2(bool) = log_or %1(bool), true
+ EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+ %fn1 = block
+ ret true
+func_end
+
+%fn2 = func test_function
+ %fn3 = block
+ %1(bool) = call my_func
+ %2(bool) = var function read_write
+ store %2(bool), %1(bool)
+ branch %fn4
+
+ %fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7]
+ # true branch
+ # false branch
+ %fn6 = block
+ store %2(bool), true
+ branch %fn7
+
+ # if merge
+ %fn7 = block
+ ret
+func_end
+
)");
}
@@ -1955,22 +1990,39 @@
GreaterThan(2.5_f, Div(Call("my_func"), Mul(2.3_f, Call("my_func")))));
WrapInFunction(expr);
- auto& b = CreateBuilder();
- InjectFlowBlock();
- auto r = b.EmitExpression(expr);
- ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
- ASSERT_TRUE(r);
+ auto r = Build();
+ ASSERT_TRUE(r) << Error();
+ auto m = r.Move();
- Disassembler d(b.builder.ir);
- d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
- EXPECT_EQ(d.AsString(), R"(%1(f32) = call my_func
-%2(bool) = lt %1(f32), 2.0f
-%3(f32) = call my_func
-%4(f32) = call my_func
-%5(f32) = mul 2.29999995231628417969f, %4(f32)
-%6(f32) = div %3(f32), %5(f32)
-%7(bool) = gt 2.5f, %6(f32)
-%8(bool) = log_and %2(bool), %7(bool)
+ EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+ %fn1 = block
+ ret 0.0f
+func_end
+
+%fn2 = func test_function
+ %fn3 = block
+ %1(f32) = call my_func
+ %2(bool) = lt %1(f32), 2.0f
+ %3(bool) = var function read_write
+ store %3(bool), %2(bool)
+ branch %fn4
+
+ %fn4 = if %2(bool) [t: %fn5, f: %fn6, m: %fn7]
+ # true branch
+ %fn5 = block
+ %4(f32) = call my_func
+ %5(f32) = call my_func
+ %6(f32) = mul 2.29999995231628417969f, %5(f32)
+ %7(f32) = div %4(f32), %6(f32)
+ %8(bool) = gt 2.5f, %7(f32)
+ store %3(bool), %8(bool)
+ branch %fn7
+
+ # if merge
+ %fn7 = block
+ ret
+func_end
+
)");
}
@@ -1980,15 +2032,21 @@
GreaterThan(2.5_f, Div(10_f, Mul(2.3_f, 9.4_f)))));
WrapInFunction(expr);
- auto& b = CreateBuilder();
- InjectFlowBlock();
- auto r = b.EmitExpression(expr);
- ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
- ASSERT_TRUE(r);
+ auto r = Build();
+ ASSERT_TRUE(r) << Error();
+ auto m = r.Move();
- Disassembler d(b.builder.ir);
- d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
- EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func, false
+ EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+ %fn1 = block
+ ret true
+func_end
+
+%fn2 = func test_function
+ %fn3 = block
+ %1(bool) = call my_func, false
+ ret
+func_end
+
)");
}
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 2b8de0b..30da559 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -203,8 +203,10 @@
Indent() << "# true branch" << std::endl;
Walk(i->true_.target);
- Indent() << "# false branch" << std::endl;
- Walk(i->false_.target);
+ if (!i->false_.target->IsDead()) {
+ Indent() << "# false branch" << std::endl;
+ Walk(i->false_.target);
+ }
}
if (i->merge.target->IsConnected()) {