spirv-writer: Generate load if needed for continue block conditional exit

Fixed: tint:1343
Change-Id: Ic105e407c572f1c309da8f21908a16c08b081f7f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72641
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 5b125b4..a0eba5a 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -1183,6 +1183,13 @@
   return val;
 }
 
+uint32_t Builder::GenerateNonReferenceExpression(const ast::Expression* expr) {
+  if (const auto id = GenerateExpression(expr)) {
+    return GenerateLoadIfNeeded(TypeOf(expr), id);
+  }
+  return 0;
+}
+
 uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
   if (auto* ref = type->As<sem::Reference>()) {
     type = ref->StoreType();
@@ -3578,7 +3585,10 @@
     if (is_just_a_break(stmt->body) && stmt->else_statements.empty()) {
       // It's a break-if.
       TINT_ASSERT(Writer, !backedge_stack_.empty());
-      const auto cond_id = GenerateExpression(stmt->condition);
+      const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
+      if (!cond_id) {
+        return false;
+      }
       backedge_stack_.back() =
           Backedge(spv::Op::OpBranchConditional,
                    {Operand::Int(cond_id), Operand::Int(ci.break_target_id),
@@ -3590,7 +3600,10 @@
           is_just_a_break(es.back()->body)) {
         // It's a break-unless.
         TINT_ASSERT(Writer, !backedge_stack_.empty());
-        const auto cond_id = GenerateExpression(stmt->condition);
+        const auto cond_id = GenerateNonReferenceExpression(stmt->condition);
+        if (!cond_id) {
+          return false;
+        }
         backedge_stack_.back() =
             Backedge(spv::Op::OpBranchConditional,
                      {Operand::Int(cond_id), Operand::Int(ci.loop_header_id),
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index d85e988..1b43ee4 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -454,9 +454,17 @@
   /// @param stmt the statement to generate
   /// @returns true if the statement was generated
   bool GenerateStatement(const ast::Statement* stmt);
-  /// Geneates an OpLoad
-  /// @param type the type to load
-  /// @param id the variable id to load
+  /// Generates an expression. If the WGSL expression does not have reference
+  /// type, then return the SPIR-V ID for the expression. Otherwise implement
+  /// the WGSL Load Rule: generate an OpLoad and return the ID of the result.
+  /// Returns 0 if the expression could not be generated.
+  /// @param expr the expression to be generated
+  /// @returns the ID of the expression, or of the loaded expression
+  uint32_t GenerateNonReferenceExpression(const ast::Expression* expr);
+  /// Generates an OpLoad on the given ID if it has reference type in WGSL,
+  /// otherwise return the ID itself.
+  /// @param type the type of the expression
+  /// @param id the SPIR-V id of the expression
   /// @returns the ID of the loaded value or `id` if type is not a reference
   uint32_t GenerateLoadIfNeeded(const sem::Type* type, uint32_t id);
   /// Generates an OpStore. Emits an error and returns false if we're
diff --git a/src/writer/spirv/builder_loop_test.cc b/src/writer/spirv/builder_loop_test.cc
index 56d0aac..a074461 100644
--- a/src/writer/spirv/builder_loop_test.cc
+++ b/src/writer/spirv/builder_loop_test.cc
@@ -287,6 +287,84 @@
 )");
 }
 
+TEST_F(BuilderTest, Loop_WithContinuing_BreakIf_ConditionIsVar) {
+  // loop {
+  //   continuing {
+  //     var cond = true;
+  //     if (cond) { break; }
+  //   }
+  // }
+
+  auto* cond_var = Decl(Var("cond", nullptr, Expr(true)));
+  auto* if_stmt = If(Expr("cond"), Block(Break()), ast::ElseStatementList{});
+  auto* continuing = Block(cond_var, if_stmt);
+  auto* loop = Loop(Block(), continuing);
+  WrapInFunction(loop);
+
+  spirv::Builder& b = Build();
+
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
+%6 = OpConstantTrue %5
+%8 = OpTypePointer Function %5
+%9 = OpConstantNull %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %1
+%1 = OpLabel
+OpLoopMerge %2 %3 None
+OpBranch %4
+%4 = OpLabel
+OpBranch %3
+%3 = OpLabel
+OpStore %7 %6
+%10 = OpLoad %5 %7
+OpBranchConditional %10 %2 %1
+%2 = OpLabel
+)");
+}
+
+TEST_F(BuilderTest, Loop_WithContinuing_BreakUnless_ConditionIsVar) {
+  // loop {
+  //   continuing {
+  //     var cond = true;
+  //     if (cond) {} else { break; }
+  //   }
+  // }
+  auto* cond_var = Decl(Var("cond", nullptr, Expr(true)));
+  auto* if_stmt = If(Expr("cond"), Block(),
+                     ast::ElseStatementList{Else(nullptr, Block(Break()))});
+  auto* continuing = Block(cond_var, if_stmt);
+  auto* loop = Loop(Block(), continuing);
+  WrapInFunction(loop);
+
+  spirv::Builder& b = Build();
+
+  b.push_function(Function{});
+
+  EXPECT_TRUE(b.GenerateLoopStatement(loop)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeBool
+%6 = OpConstantTrue %5
+%8 = OpTypePointer Function %5
+%9 = OpConstantNull %5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpBranch %1
+%1 = OpLabel
+OpLoopMerge %2 %3 None
+OpBranch %4
+%4 = OpLabel
+OpBranch %3
+%3 = OpLabel
+OpStore %7 %6
+%10 = OpLoad %5 %7
+OpBranchConditional %10 %1 %2
+%2 = OpLabel
+)");
+}
+
 TEST_F(BuilderTest, Loop_WithContinuing_BreakIf_Nested) {
   // Make sure the right backedge and break target are used.
   // loop {