Add break-if support.

This CL adds support for the `break-if` statement to Tint, covering the GLSL, HLSL, MSL, SPIR-V and WGSL writers.

Bug: tint:1633, tint:1451
Change-Id: I30dfd62a3e09255624ff76ebe0cdd3a3c7cf9c5f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/106420
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: dan sinclair <dsinclair@google.com>
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index ab28bf7..fc9de73 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -712,6 +712,16 @@
     return true;
 }
 
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* b) {
+    auto out = line();  // Emitted on one line as `if (<condition>) { break; }`.
+    out << "if (";
+    if (EmitExpression(out, b->condition)) {
+        out << ") { break; }";
+        return true;
+    }
+    return false;  // The condition expression could not be emitted.
+}
+
 bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
     auto* call = builder_.Sem().Get<sem::Call>(expr);
     auto* target = call->Target();
@@ -2616,6 +2626,7 @@
         [&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
         [&](const ast::BlockStatement* b) { return EmitBlock(b); },
         [&](const ast::BreakStatement* b) { return EmitBreak(b); },
+        [&](const ast::BreakIfStatement* b) { return EmitBreakIf(b); },
         [&](const ast::CallStatement* c) {
             auto out = line();
             if (!EmitCall(out, c->expr)) {
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index 2cd5d5d..889b406 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -139,6 +139,10 @@
     /// @param stmt the statement to emit
     /// @returns true if the statement was emitted successfully
     bool EmitBreak(const ast::BreakStatement* stmt);
+    /// Handles a break-if statement, emitted as `if (<condition>) { break; }`
+    /// @param stmt the statement to emit
+    /// @returns true if the statement was emitted successfully
+    bool EmitBreakIf(const ast::BreakIfStatement* stmt);
     /// Handles generating a call expression
     /// @param out the output of the expression stream
     /// @param expr the call expression
diff --git a/src/tint/writer/glsl/generator_impl_loop_test.cc b/src/tint/writer/glsl/generator_impl_loop_test.cc
index e638eb3..7aa94be 100644
--- a/src/tint/writer/glsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/glsl/generator_impl_loop_test.cc
@@ -65,6 +65,31 @@
 )");
 }
 
+TEST_F(GlslGeneratorImplTest_Loop, Emit_LoopWithContinuing_BreakIf) {
+    // A break-if that ends a continuing block is emitted as `if (...) { break; }`.
+    Func("a_statement", {}, ty.void_(), {});
+
+    auto* loop_body = Block(create<ast::DiscardStatement>());
+    auto* continuing_block = Block(CallStmt(Call("a_statement")), BreakIf(true));
+    auto* loop = Loop(loop_body, continuing_block);
+
+    Func("F", utils::Empty, ty.void_(), utils::Vector{loop},
+         utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+    GeneratorImpl& gen = Build();
+    gen.increment_indent();  // The expected output below starts indented.
+
+    ASSERT_TRUE(gen.EmitStatement(loop)) << gen.error();
+    EXPECT_EQ(gen.result(), R"(  while (true) {
+    discard;
+    {
+      a_statement();
+      if (true) { break; }
+    }
+  }
+)");
+}
+
 TEST_F(GlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
     Func("a_statement", {}, ty.void_(), {});
 
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index 1dc3b10..e42274d 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -942,6 +942,16 @@
     return true;
 }
 
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* b) {
+    auto out = line();  // Emitted on one line as `if (<condition>) { break; }`.
+    out << "if (";
+    if (EmitExpression(out, b->condition)) {
+        out << ") { break; }";
+        return true;
+    }
+    return false;  // The condition expression could not be emitted.
+}
+
 bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
     auto* call = builder_.Sem().Get<sem::Call>(expr);
     auto* target = call->Target();
@@ -3591,6 +3601,9 @@
         [&](const ast::BreakStatement* b) {  //
             return EmitBreak(b);
         },
+        [&](const ast::BreakIfStatement* b) {  //
+            return EmitBreakIf(b);
+        },
         [&](const ast::CallStatement* c) {  //
             auto out = line();
             if (!EmitCall(out, c->expr)) {
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index faef168..2742f00 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -125,6 +125,10 @@
     /// @param stmt the statement to emit
     /// @returns true if the statement was emitted successfully
     bool EmitBreak(const ast::BreakStatement* stmt);
+    /// Handles a break-if statement, emitted as `if (<condition>) { break; }`
+    /// @param stmt the statement to emit
+    /// @returns true if the statement was emitted successfully
+    bool EmitBreakIf(const ast::BreakIfStatement* stmt);
     /// Handles generating a call expression
     /// @param out the output of the expression stream
     /// @param expr the call expression
diff --git a/src/tint/writer/hlsl/generator_impl_loop_test.cc b/src/tint/writer/hlsl/generator_impl_loop_test.cc
index 238fdd2..3d8219b 100644
--- a/src/tint/writer/hlsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_loop_test.cc
@@ -65,6 +65,31 @@
 )");
 }
 
+TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing_BreakIf) {
+    // A break-if that ends a continuing block is emitted as `if (...) { break; }`.
+    Func("a_statement", {}, ty.void_(), {});
+
+    auto* loop_body = Block(create<ast::DiscardStatement>());
+    auto* continuing_block = Block(CallStmt(Call("a_statement")), BreakIf(true));
+    auto* loop = Loop(loop_body, continuing_block);
+
+    Func("F", utils::Empty, ty.void_(), utils::Vector{loop},
+         utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+    GeneratorImpl& gen = Build();
+    gen.increment_indent();  // The expected output below starts indented.
+
+    ASSERT_TRUE(gen.EmitStatement(loop)) << gen.error();
+    EXPECT_EQ(gen.result(), R"(  while (true) {
+    discard;
+    {
+      a_statement();
+      if (true) { break; }
+    }
+  }
+)");
+}
+
 TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
     Func("a_statement", {}, ty.void_(), {});
 
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index c7c7b63..28478b8 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -660,6 +660,16 @@
     return true;
 }
 
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* b) {
+    auto out = line();  // Emitted on one line as `if (<condition>) { break; }`.
+    out << "if (";
+    if (EmitExpression(out, b->condition)) {
+        out << ") { break; }";
+        return true;
+    }
+    return false;  // The condition expression could not be emitted.
+}
+
 bool GeneratorImpl::EmitCall(std::ostream& out, const ast::CallExpression* expr) {
     auto* call = program_->Sem().Get<sem::Call>(expr);
     auto* target = call->Target();
@@ -2433,6 +2443,9 @@
         [&](const ast::BreakStatement* b) {  //
             return EmitBreak(b);
         },
+        [&](const ast::BreakIfStatement* b) {  //
+            return EmitBreakIf(b);
+        },
         [&](const ast::CallStatement* c) {  //
             auto out = line();
             if (!EmitCall(out, c->expr)) {  //
diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h
index 5e16fbe..188bfea 100644
--- a/src/tint/writer/msl/generator_impl.h
+++ b/src/tint/writer/msl/generator_impl.h
@@ -129,6 +129,10 @@
     /// @param stmt the statement to emit
     /// @returns true if the statement was emitted successfully
     bool EmitBreak(const ast::BreakStatement* stmt);
+    /// Handles a break-if statement, emitted as `if (<condition>) { break; }`
+    /// @param stmt the statement to emit
+    /// @returns true if the statement was emitted successfully
+    bool EmitBreakIf(const ast::BreakIfStatement* stmt);
     /// Handles generating a call expression
     /// @param out the output of the expression stream
     /// @param expr the call expression
diff --git a/src/tint/writer/msl/generator_impl_loop_test.cc b/src/tint/writer/msl/generator_impl_loop_test.cc
index 274ee941..1dd5430 100644
--- a/src/tint/writer/msl/generator_impl_loop_test.cc
+++ b/src/tint/writer/msl/generator_impl_loop_test.cc
@@ -65,6 +65,31 @@
 )");
 }
 
+TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing_BreakIf) {
+    // A break-if that ends a continuing block is emitted as `if (...) { break; }`.
+    Func("a_statement", {}, ty.void_(), {});
+
+    auto* loop_body = Block(create<ast::DiscardStatement>());
+    auto* continuing_block = Block(CallStmt(Call("a_statement")), BreakIf(true));
+    auto* loop = Loop(loop_body, continuing_block);
+
+    Func("F", utils::Empty, ty.void_(), utils::Vector{loop},
+         utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+    GeneratorImpl& gen = Build();
+    gen.increment_indent();  // The expected output below starts indented.
+
+    ASSERT_TRUE(gen.EmitStatement(loop)) << gen.error();
+    EXPECT_EQ(gen.result(), R"(  while (true) {
+    discard_fragment();
+    {
+      a_statement();
+      if (true) { break; }
+    }
+  }
+)");
+}
+
 TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
     Func("a_statement", {}, ty.void_(), utils::Empty);
 
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index c93a448..91b16b4 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -455,6 +455,19 @@
     return true;
 }
 
+bool Builder::GenerateBreakIfStatement(const ast::BreakIfStatement* stmt) {
+    TINT_ASSERT(Writer, !backedge_stack_.empty() && !continuing_stack_.empty());
+    const auto cond_id = GenerateExpressionWithLoadIfNeeded(stmt->condition);
+    if (!cond_id) {
+        return false;  // The condition expression could not be generated.
+    }
+    const ContinuingInfo& ci = continuing_stack_.back();  // Asserted non-empty above.
+    backedge_stack_.back() =  // Conditional backedge: break target if true, else loop header.
+        Backedge(spv::Op::OpBranchConditional,
+                 {Operand(cond_id), Operand(ci.break_target_id), Operand(ci.loop_header_id)});
+    return true;
+}
+
 bool Builder::GenerateContinueStatement(const ast::ContinueStatement*) {
     if (continue_stack_.empty()) {
         error_ = "Attempted to continue without a continue block";
@@ -3400,6 +3413,8 @@
         //  continuing { ...
         //    if (cond) {} else {break;}
         //  }
+        //
+        // TODO(crbug.com/tint/1451): Remove this when the if break construct is made an error.
         auto is_just_a_break = [](const ast::BlockStatement* block) {
             return block && (block->statements.Length() == 1) &&
                    block->Last()->Is<ast::BreakStatement>();
@@ -3643,6 +3658,7 @@
         stmt, [&](const ast::AssignmentStatement* a) { return GenerateAssignStatement(a); },
         [&](const ast::BlockStatement* b) { return GenerateBlockStatement(b); },
         [&](const ast::BreakStatement* b) { return GenerateBreakStatement(b); },
+        [&](const ast::BreakIfStatement* b) { return GenerateBreakIfStatement(b); },
         [&](const ast::CallStatement* c) { return GenerateCallExpression(c->expr) != 0; },
         [&](const ast::ContinueStatement* c) { return GenerateContinueStatement(c); },
         [&](const ast::DiscardStatement* d) { return GenerateDiscardStatement(d); },
@@ -3659,7 +3675,7 @@
             return true;  // Not emitted
         },
         [&](Default) {
-            error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
+            error_ = "unknown statement type: " + std::string(stmt->TypeInfo().name);
             return false;
         });
 }
diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h
index 9e633f1..1e412fa 100644
--- a/src/tint/writer/spirv/builder.h
+++ b/src/tint/writer/spirv/builder.h
@@ -248,6 +248,10 @@
     /// @param stmt the statement to generate
     /// @returns true if the statement was successfully generated
     bool GenerateBreakStatement(const ast::BreakStatement* stmt);
+    /// Generates a break-if statement as a conditional loop backedge (OpBranchConditional)
+    /// @param stmt the statement to generate
+    /// @returns true if the statement was successfully generated
+    bool GenerateBreakIfStatement(const ast::BreakIfStatement* stmt);
     /// Generates a continue statement
     /// @param stmt the statement to generate
     /// @returns true if the statement was successfully generated
diff --git a/src/tint/writer/spirv/builder_loop_test.cc b/src/tint/writer/spirv/builder_loop_test.cc
index 3ab27ef..e27cac4 100644
--- a/src/tint/writer/spirv/builder_loop_test.cc
+++ b/src/tint/writer/spirv/builder_loop_test.cc
@@ -234,12 +234,11 @@
 TEST_F(BuilderTest, Loop_WithContinuing_BreakIf) {
     // loop {
     //   continuing {
-    //     if (true) { break; }
+    //     break if (true);
     //   }
     // }
 
-    auto* if_stmt = If(Expr(true), Block(Break()));
-    auto* continuing = Block(if_stmt);
+    auto* continuing = Block(BreakIf(true));
     auto* loop = Loop(Block(), continuing);
     WrapInFunction(loop);
 
@@ -267,11 +266,10 @@
 TEST_F(BuilderTest, Loop_WithContinuing_BreakUnless) {
     // loop {
     //   continuing {
-    //     if (true) {} else { break; }
+    //     break if (false);
     //   }
     // }
-    auto* if_stmt = If(Expr(true), Block(), Else(Block(Break())));
-    auto* continuing = Block(if_stmt);
+    auto* continuing = Block(BreakIf(false));
     auto* loop = Loop(Block(), continuing);
     WrapInFunction(loop);
 
@@ -281,7 +279,7 @@
 
     EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error();
     EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
-%6 = OpConstantTrue %5
+%6 = OpConstantNull %5
 )");
     EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
               R"(OpBranch %1
@@ -291,7 +289,7 @@
 %4 = OpLabel
 OpBranch %3
 %3 = OpLabel
-OpBranchConditional %6 %1 %2
+OpBranchConditional %6 %2 %1
 %2 = OpLabel
 )");
 }
@@ -300,13 +298,12 @@
     // loop {
     //   continuing {
     //     var cond = true;
-    //     if (cond) { break; }
+    //     break if (cond);
     //   }
     // }
 
     auto* cond_var = Decl(Var("cond", Expr(true)));
-    auto* if_stmt = If(Expr("cond"), Block(Break()));
-    auto* continuing = Block(cond_var, if_stmt);
+    auto* continuing = Block(cond_var, BreakIf("cond"));
     auto* loop = Loop(Block(), continuing);
     WrapInFunction(loop);
 
@@ -379,19 +376,17 @@
     //   continuing {
     //     loop {
     //       continuing {
-    //         if (true) { break; }
+    //         break if (true);
     //       }
     //     }
-    //     if (true) { break; }
+    //     break if (true);
     //   }
     // }
 
-    auto* inner_if_stmt = If(Expr(true), Block(Break()));
-    auto* inner_continuing = Block(inner_if_stmt);
+    auto* inner_continuing = Block(BreakIf(true));
     auto* inner_loop = Loop(Block(), inner_continuing);
 
-    auto* outer_if_stmt = If(Expr(true), Block(Break()));
-    auto* outer_continuing = Block(inner_loop, outer_if_stmt);
+    auto* outer_continuing = Block(inner_loop, BreakIf(true));
     auto* outer_loop = Loop(Block(), outer_continuing);
 
     WrapInFunction(outer_loop);
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index c62fbf5..bf7ebd7 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -958,6 +958,7 @@
         [&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
         [&](const ast::BlockStatement* b) { return EmitBlock(b); },
         [&](const ast::BreakStatement* b) { return EmitBreak(b); },
+        [&](const ast::BreakIfStatement* b) { return EmitBreakIf(b); },
         [&](const ast::CallStatement* c) {
             auto out = line();
             if (!EmitCall(out, c->expr)) {
@@ -1023,6 +1024,17 @@
     return true;
 }
 
+bool GeneratorImpl::EmitBreakIf(const ast::BreakIfStatement* b) {
+    // Emitted on one line as `break if <condition>;`.
+    auto out = line();
+    out << "break if ";
+    if (EmitExpression(out, b->condition)) {
+        out << ";";
+        return true;
+    }
+    return false;  // The condition expression could not be emitted.
+}
+
 bool GeneratorImpl::EmitCase(const ast::CaseStatement* stmt) {
     if (stmt->selectors.Length() == 1 && stmt->ContainsDefault()) {
         line() << "default: {";
diff --git a/src/tint/writer/wgsl/generator_impl.h b/src/tint/writer/wgsl/generator_impl.h
index f4fc467..2b7f1c9 100644
--- a/src/tint/writer/wgsl/generator_impl.h
+++ b/src/tint/writer/wgsl/generator_impl.h
@@ -20,6 +20,7 @@
 #include "src/tint/ast/assignment_statement.h"
 #include "src/tint/ast/binary_expression.h"
 #include "src/tint/ast/bitcast_expression.h"
+#include "src/tint/ast/break_if_statement.h"
 #include "src/tint/ast/break_statement.h"
 #include "src/tint/ast/compound_assignment_statement.h"
 #include "src/tint/ast/continue_statement.h"
@@ -92,6 +93,10 @@
     /// @param stmt the statement to emit
     /// @returns true if the statement was emitted successfully
     bool EmitBreak(const ast::BreakStatement* stmt);
+    /// Handles a break-if statement, emitted as `break if <condition>;`
+    /// @param stmt the statement to emit
+    /// @returns true if the statement was emitted successfully
+    bool EmitBreakIf(const ast::BreakIfStatement* stmt);
     /// Handles generating a call expression
     /// @param out the output of the expression stream
     /// @param expr the call expression
diff --git a/src/tint/writer/wgsl/generator_impl_loop_test.cc b/src/tint/writer/wgsl/generator_impl_loop_test.cc
index bf0eef6..48510bb 100644
--- a/src/tint/writer/wgsl/generator_impl_loop_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_loop_test.cc
@@ -65,6 +65,32 @@
 )");
 }
 
+TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing_BreakIf) {
+    // A break-if that ends a continuing block is emitted as `break if <condition>;`.
+    Func("a_statement", {}, ty.void_(), {});
+
+    auto* loop_body = Block(create<ast::DiscardStatement>());
+    auto* continuing_block = Block(CallStmt(Call("a_statement")), BreakIf(true));
+    auto* loop = Loop(loop_body, continuing_block);
+
+    Func("F", utils::Empty, ty.void_(), utils::Vector{loop},
+         utils::Vector{Stage(ast::PipelineStage::kFragment)});
+
+    GeneratorImpl& gen = Build();
+    gen.increment_indent();  // The expected output below starts indented.
+
+    ASSERT_TRUE(gen.EmitStatement(loop)) << gen.error();
+    EXPECT_EQ(gen.result(), R"(  loop {
+    discard;
+
+    continuing {
+      a_statement();
+      break if true;
+    }
+  }
+)");
+}
+
 TEST_F(WgslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
     // var<workgroup> a : atomic<i32>;
     // for({ignore(1i); ignore(2i);}; ; ) {