[ir] Emit short-circuit as an `If` node

This Cl removes the `&&` and `||` logical binary nodes and replaces them
with a var declaration and if node.

Bug: tint:1925
Change-Id: I9f25411a9b9c909fa25f2f37cbd51181ac584acc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/130500
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc
index 19caf6f..fd6f500 100644
--- a/src/tint/ir/builder_impl.cc
+++ b/src/tint/ir/builder_impl.cc
@@ -778,7 +778,66 @@
     return inst;
 }
 
+// A short-circut needs special treatment. The short-circuit is decomposed into the relevant if
+// statements and declarations.
+utils::Result<Value*> BuilderImpl::EmitShortCircuit(const ast::BinaryExpression* expr) {
+    switch (expr->op) {
+        case ast::BinaryOp::kLogicalAnd:
+        case ast::BinaryOp::kLogicalOr:
+            break;
+        default:
+            TINT_ICE(IR, diagnostics_) << "invalid operation type for short-circut decomposition";
+            return utils::Failure;
+    }
+
+    auto lhs = EmitExpression(expr->lhs);
+    if (!lhs) {
+        return utils::Failure;
+    }
+
+    auto* ty = builder.ir.types.Get<type::Bool>();
+    auto* result_var =
+        builder.Declare(ty, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite);
+    current_flow_block->instructions.Push(result_var);
+
+    auto* lhs_store = builder.Store(result_var, lhs.Get());
+    current_flow_block->instructions.Push(lhs_store);
+
+    auto* if_node = builder.CreateIf();
+    if_node->condition = lhs.Get();
+    BranchTo(if_node);
+
+    utils::Result<Value*> rhs;
+    {
+        FlowStackScope scope(this, if_node);
+
+        // If this is an `&&` then we only evaluate the RHS expression in the true block.
+        // If this is an `||` then we only evaluate the RHS expression in the false block.
+        if (expr->op == ast::BinaryOp::kLogicalAnd) {
+            current_flow_block = if_node->true_.target->As<Block>();
+        } else {
+            current_flow_block = if_node->false_.target->As<Block>();
+        }
+
+        rhs = EmitExpression(expr->rhs);
+        if (!rhs) {
+            return utils::Failure;
+        }
+        auto* rhs_store = builder.Store(result_var, rhs.Get());
+        current_flow_block->instructions.Push(rhs_store);
+
+        BranchTo(if_node->merge.target);
+    }
+    current_flow_block = if_node->merge.target->As<Block>();
+
+    return result_var;
+}
+
 utils::Result<Value*> BuilderImpl::EmitBinary(const ast::BinaryExpression* expr) {
+    if (expr->op == ast::BinaryOp::kLogicalAnd || expr->op == ast::BinaryOp::kLogicalOr) {
+        return EmitShortCircuit(expr);
+    }
+
     auto lhs = EmitExpression(expr->lhs);
     if (!lhs) {
         return utils::Failure;
@@ -803,12 +862,6 @@
         case ast::BinaryOp::kXor:
             inst = builder.Xor(ty, lhs.Get(), rhs.Get());
             break;
-        case ast::BinaryOp::kLogicalAnd:
-            inst = builder.LogicalAnd(ty, lhs.Get(), rhs.Get());
-            break;
-        case ast::BinaryOp::kLogicalOr:
-            inst = builder.LogicalOr(ty, lhs.Get(), rhs.Get());
-            break;
         case ast::BinaryOp::kEqual:
             inst = builder.Equal(ty, lhs.Get(), rhs.Get());
             break;
@@ -848,6 +901,10 @@
         case ast::BinaryOp::kModulo:
             inst = builder.Modulo(ty, lhs.Get(), rhs.Get());
             break;
+        case ast::BinaryOp::kLogicalAnd:
+        case ast::BinaryOp::kLogicalOr:
+            TINT_ICE(IR, diagnostics_) << "short circuit op should have already been handled";
+            return utils::Failure;
         case ast::BinaryOp::kNone:
             TINT_ICE(IR, diagnostics_) << "missing binary operand type";
             return utils::Failure;
diff --git a/src/tint/ir/builder_impl.h b/src/tint/ir/builder_impl.h
index 7b18492..58d4844 100644
--- a/src/tint/ir/builder_impl.h
+++ b/src/tint/ir/builder_impl.h
@@ -166,6 +166,11 @@
     /// @returns the value storing the result if successful, utils::Failure otherwise
     utils::Result<Value*> EmitUnary(const ast::UnaryOpExpression* expr);
 
+    /// Emits a short-circult binary expression
+    /// @param expr the binary expression
+    /// @returns the value storing the result if successful, utils::Failure otherwise
+    utils::Result<Value*> EmitShortCircuit(const ast::BinaryExpression* expr);
+
     /// Emits a binary expression
     /// @param expr the binary expression
     /// @returns the value storing the result if successful, utils::Failure otherwise
diff --git a/src/tint/ir/builder_impl_test.cc b/src/tint/ir/builder_impl_test.cc
index 1e6a26d..bcba641 100644
--- a/src/tint/ir/builder_impl_test.cc
+++ b/src/tint/ir/builder_impl_test.cc
@@ -1773,16 +1773,33 @@
     auto* expr = LogicalAnd(Call("my_func"), false);
     WrapInFunction(expr);
 
-    auto& b = CreateBuilder();
-    InjectFlowBlock();
-    auto r = b.EmitExpression(expr);
-    ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
-    ASSERT_TRUE(r);
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
 
-    Disassembler d(b.builder.ir);
-    d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
-    EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func
-%2(bool) = log_and %1(bool), false
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+  %fn1 = block
+  ret true
+func_end
+
+%fn2 = func test_function
+  %fn3 = block
+  %1(bool) = call my_func
+  %2(bool) = var function read_write
+  store %2(bool), %1(bool)
+  branch %fn4
+
+  %fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7]
+    # true branch
+    %fn5 = block
+    store %2(bool), false
+    branch %fn7
+
+  # if merge
+  %fn7 = block
+  ret
+func_end
+
 )");
 }
 
@@ -1791,16 +1808,34 @@
     auto* expr = LogicalOr(Call("my_func"), true);
     WrapInFunction(expr);
 
-    auto& b = CreateBuilder();
-    InjectFlowBlock();
-    auto r = b.EmitExpression(expr);
-    ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
-    ASSERT_TRUE(r);
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
 
-    Disassembler d(b.builder.ir);
-    d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
-    EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func
-%2(bool) = log_or %1(bool), true
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+  %fn1 = block
+  ret true
+func_end
+
+%fn2 = func test_function
+  %fn3 = block
+  %1(bool) = call my_func
+  %2(bool) = var function read_write
+  store %2(bool), %1(bool)
+  branch %fn4
+
+  %fn4 = if %1(bool) [t: %fn5, f: %fn6, m: %fn7]
+    # true branch
+    # false branch
+    %fn6 = block
+    store %2(bool), true
+    branch %fn7
+
+  # if merge
+  %fn7 = block
+  ret
+func_end
+
 )");
 }
 
@@ -1955,22 +1990,39 @@
                             GreaterThan(2.5_f, Div(Call("my_func"), Mul(2.3_f, Call("my_func")))));
     WrapInFunction(expr);
 
-    auto& b = CreateBuilder();
-    InjectFlowBlock();
-    auto r = b.EmitExpression(expr);
-    ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
-    ASSERT_TRUE(r);
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
 
-    Disassembler d(b.builder.ir);
-    d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
-    EXPECT_EQ(d.AsString(), R"(%1(f32) = call my_func
-%2(bool) = lt %1(f32), 2.0f
-%3(f32) = call my_func
-%4(f32) = call my_func
-%5(f32) = mul 2.29999995231628417969f, %4(f32)
-%6(f32) = div %3(f32), %5(f32)
-%7(bool) = gt 2.5f, %6(f32)
-%8(bool) = log_and %2(bool), %7(bool)
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+  %fn1 = block
+  ret 0.0f
+func_end
+
+%fn2 = func test_function
+  %fn3 = block
+  %1(f32) = call my_func
+  %2(bool) = lt %1(f32), 2.0f
+  %3(bool) = var function read_write
+  store %3(bool), %2(bool)
+  branch %fn4
+
+  %fn4 = if %2(bool) [t: %fn5, f: %fn6, m: %fn7]
+    # true branch
+    %fn5 = block
+    %4(f32) = call my_func
+    %5(f32) = call my_func
+    %6(f32) = mul 2.29999995231628417969f, %5(f32)
+    %7(f32) = div %4(f32), %6(f32)
+    %8(bool) = gt 2.5f, %7(f32)
+    store %3(bool), %8(bool)
+    branch %fn7
+
+  # if merge
+  %fn7 = block
+  ret
+func_end
+
 )");
 }
 
@@ -1980,15 +2032,21 @@
                                             GreaterThan(2.5_f, Div(10_f, Mul(2.3_f, 9.4_f)))));
     WrapInFunction(expr);
 
-    auto& b = CreateBuilder();
-    InjectFlowBlock();
-    auto r = b.EmitExpression(expr);
-    ASSERT_THAT(b.Diagnostics(), testing::IsEmpty());
-    ASSERT_TRUE(r);
+    auto r = Build();
+    ASSERT_TRUE(r) << Error();
+    auto m = r.Move();
 
-    Disassembler d(b.builder.ir);
-    d.EmitBlockInstructions(b.current_flow_block->As<ir::Block>());
-    EXPECT_EQ(d.AsString(), R"(%1(bool) = call my_func, false
+    EXPECT_EQ(Disassemble(m), R"(%fn0 = func my_func
+  %fn1 = block
+  ret true
+func_end
+
+%fn2 = func test_function
+  %fn3 = block
+  %1(bool) = call my_func, false
+  ret
+func_end
+
 )");
 }
 
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 2b8de0b..30da559 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -203,8 +203,10 @@
                 Indent() << "# true branch" << std::endl;
                 Walk(i->true_.target);
 
-                Indent() << "# false branch" << std::endl;
-                Walk(i->false_.target);
+                if (!i->false_.target->IsDead()) {
+                    Indent() << "# false branch" << std::endl;
+                    Walk(i->false_.target);
+                }
             }
 
             if (i->merge.target->IsConnected()) {