spirv-reader: handle break and continue from if-selection header

Fixed: tint:243, tint:494
Change-Id: I6baf3360b44042b52f510b8f761376e1daab878f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47540
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Alan Baker <alanbaker@google.com>
Reviewed-by: Alan Baker <alanbaker@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 7ce03ed..d31e56d 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -1799,13 +1799,15 @@
 
     // The cases for each edge are:
     //  - kBack: invalid because it's an invalid exit from the selection
-    //  - kSwitchBreak
-    //  - kLoopBreak
-    //  - kLoopContinue
+    //  - kSwitchBreak ; record this for later special processing
+    //  - kLoopBreak ; record this for later special processing
+    //  - kLoopContinue ; record this for later special processing
     //  - kIfBreak; normal case, may require a guard variable.
     //  - kFallThrough; invalid exit from the selection
     //  - kForward; normal case
 
+    if_header_info->true_kind = if_header_info->succ_edge[true_head];
+    if_header_info->false_kind = if_header_info->succ_edge[false_head];
     if (contains_true) {
       if_header_info->true_head = true_head;
     }
@@ -2344,9 +2346,19 @@
   //   ends at the premerge head (if it exists) or at the selection end.
   const uint32_t else_end = premerge_head ? premerge_head : intended_merge;
 
+  const bool true_is_break = (block_info.true_kind == EdgeKind::kSwitchBreak) ||
+                             (block_info.true_kind == EdgeKind::kLoopBreak);
+  const bool false_is_break =
+      (block_info.false_kind == EdgeKind::kSwitchBreak) ||
+      (block_info.false_kind == EdgeKind::kLoopBreak);
+  const bool true_is_continue = block_info.true_kind == EdgeKind::kLoopContinue;
+  const bool false_is_continue =
+      block_info.false_kind == EdgeKind::kLoopContinue;
+
   // Push statement blocks for the then-clause and the else-clause.
   // But make sure we do it in the right order.
-  auto push_else = [this, builder, else_end, construct]() {
+  auto push_else = [this, builder, else_end, construct, false_is_break,
+                    false_is_continue]() {
     // Push the else clause onto the stack first.
     PushNewStatementBlock(
         construct, else_end, [=](const ast::StatementList& stmts) {
@@ -2359,9 +2371,16 @@
                 create<ast::ElseStatement>(Source{}, nullptr, else_body));
           }
         });
+    if (false_is_break) {
+      AddStatement(create<ast::BreakStatement>(Source{}));
+    }
+    if (false_is_continue) {
+      AddStatement(create<ast::ContinueStatement>(Source{}));
+    }
   };
 
-  if (GetBlockInfo(else_end)->pos < GetBlockInfo(then_end)->pos) {
+  if (!true_is_break && !true_is_continue &&
+      (GetBlockInfo(else_end)->pos < GetBlockInfo(then_end)->pos)) {
     // Process the else-clause first.  The then-clause will be empty so avoid
     // pushing onto the stack at all.
     push_else();
@@ -2382,7 +2401,7 @@
         // just like in the original SPIR-V.
         PushTrueGuard(construct->end_id);
       } else {
-        // Add a flow guard around the blocks in the premrege area.
+        // Add a flow guard around the blocks in the premege area.
         PushGuard(guard_name, construct->end_id);
       }
     }
@@ -2399,6 +2418,12 @@
         construct, then_end, [=](const ast::StatementList& stmts) {
           builder->body = create<ast::BlockStatement>(Source{}, stmts);
         });
+    if (true_is_break) {
+      AddStatement(create<ast::BreakStatement>(Source{}));
+    }
+    if (true_is_continue) {
+      AddStatement(create<ast::ContinueStatement>(Source{}));
+    }
   }
 
   return success();
diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h
index 228df47..7454054 100644
--- a/src/reader/spirv/function.h
+++ b/src/reader/spirv/function.h
@@ -128,6 +128,12 @@
   /// The following fields record relationships among blocks in a selection
   /// construct for an OpBranchConditional instruction.
 
+  /// When this block is an if-selection header, this is the edge kind
+  /// for the true branch.
+  EdgeKind true_kind = EdgeKind::kForward;
+  /// When this block is an if-selection header, this is the edge kind
+  /// for the false branch.
+  EdgeKind false_kind = EdgeKind::kForward;
   /// If not 0, then this block is an if-selection header, and `true_head` is
   /// the target id of the true branch on the OpBranchConditional, and that
   /// target is inside the if-selection.
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index 05a1fff..90f3350 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -13955,6 +13955,310 @@
                  "parent:Function@10 scope:[1,3) in-l:Loop@20 }"));
 }
 
+TEST_F(SpvParserTest, EmitBody_IfSelection_TrueBranch_LoopBreak) {
+  // crbug.com/tint/243
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %5 = OpLabel
+     OpBranch %10
+
+     %10 = OpLabel
+     OpLoopMerge %99 %90 None
+     OpBranch %20
+
+     %20 = OpLabel
+     OpSelectionMerge %40 None
+     OpBranchConditional %cond %99 %30 ; true branch breaking is ok
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %40 = OpLabel ; selection merge
+     OpBranch %90
+
+     %90 = OpLabel ; continue target
+     OpBranch %10 ; backedge
+
+     %99 = OpLabel ; loop merge
+     OpReturn
+     OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Loop{
+  If{
+    (
+      ScalarConstructor[not set]{false}
+    )
+    {
+      Break{}
+    }
+  }
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got);
+}
+
+TEST_F(SpvParserTest, EmitBody_TrueBranch_LoopContinue) {
+  // crbug.com/tint/243
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %5 = OpLabel
+     OpBranch %10
+
+     %10 = OpLabel
+     OpLoopMerge %99 %90 None
+     OpBranch %20
+
+     %20 = OpLabel
+     OpSelectionMerge %40 None
+     OpBranchConditional %cond %90 %30 ; true branch continue is ok
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %40 = OpLabel ; selection merge
+     OpBranch %90
+
+     %90 = OpLabel ; continue target
+     OpBranch %10 ; backedge
+
+     %99 = OpLabel ; loop merge
+     OpReturn
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Loop{
+  If{
+    (
+      ScalarConstructor[not set]{false}
+    )
+    {
+      Continue{}
+    }
+  }
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got);
+}
+
+TEST_F(SpvParserTest, EmitBody_TrueBranch_SwitchBreak) {
+  // crbug.com/tint/243
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %uint_20 %99 20 %20
+
+     %20 = OpLabel
+     OpSelectionMerge %40 None
+     OpBranchConditional %cond %99 %30 ; true branch switch break is ok
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %40 = OpLabel ; if-selection merge
+     OpBranch %99
+
+     %99 = OpLabel ; switch merge
+     OpReturn
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Switch{
+  ScalarConstructor[not set]{20}
+  {
+    Case 20{
+      If{
+        (
+          ScalarConstructor[not set]{false}
+        )
+        {
+          Break{}
+        }
+      }
+    }
+    Default{
+    }
+  }
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got);
+}
+
+TEST_F(SpvParserTest, EmitBody_FalseBranch_LoopBreak) {
+  // crbug.com/tint/243
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %5 = OpLabel
+     OpBranch %10
+
+     %10 = OpLabel
+     OpLoopMerge %99 %90 None
+     OpBranch %20
+
+     %20 = OpLabel
+     OpSelectionMerge %40 None
+     OpBranchConditional %cond %30 %99 ; false branch breaking is ok
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %40 = OpLabel ; selection merge
+     OpBranch %90
+
+     %90 = OpLabel ; continue target
+     OpBranch %10 ; backedge
+
+     %99 = OpLabel ; loop merge
+     OpReturn
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Loop{
+  If{
+    (
+      ScalarConstructor[not set]{false}
+    )
+    {
+    }
+  }
+  Else{
+    {
+      Break{}
+    }
+  }
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got);
+}
+
+TEST_F(SpvParserTest, EmitBody_FalseBranch_LoopContinue) {
+  // crbug.com/tint/243
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %5 = OpLabel
+     OpBranch %10
+
+     %10 = OpLabel
+     OpLoopMerge %99 %90 None
+     OpBranch %20
+
+     %20 = OpLabel
+     OpSelectionMerge %40 None
+     OpBranchConditional %cond %30 %90 ; false branch continue is ok
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %40 = OpLabel ; selection merge
+     OpBranch %90
+
+     %90 = OpLabel ; continue target
+     OpBranch %10 ; backedge
+
+     %99 = OpLabel ; loop merge
+     OpReturn
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Loop{
+  If{
+    (
+      ScalarConstructor[not set]{false}
+    )
+    {
+    }
+  }
+  Else{
+    {
+      Continue{}
+    }
+  }
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got) << p->error();
+}
+
+TEST_F(SpvParserTest, EmitBody_FalseBranch_SwitchBreak) {
+  // crbug.com/tint/243
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpSelectionMerge %99 None
+     OpSwitch %uint_20 %99 20 %20
+
+     %20 = OpLabel
+     OpSelectionMerge %40 None
+     OpBranchConditional %cond %30 %99 ; false branch switch break is ok
+
+     %30 = OpLabel
+     OpBranch %40
+
+     %40 = OpLabel ; if-selection merge
+     OpBranch %99
+
+     %99 = OpLabel ; switch merge
+     OpReturn
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Switch{
+  ScalarConstructor[not set]{20}
+  {
+    Case 20{
+      If{
+        (
+          ScalarConstructor[not set]{false}
+        )
+        {
+        }
+      }
+      Else{
+        {
+          Break{}
+        }
+      }
+    }
+    Default{
+    }
+  }
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got);
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader
diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc
index 1543518..cb70b42 100644
--- a/src/reader/spirv/function_var_test.cc
+++ b/src/reader/spirv/function_var_test.cc
@@ -1963,6 +1963,11 @@
       }
     }
   }
+  Else{
+    {
+      Continue{}
+    }
+  }
   continuing {
     VariableDeclStatement{
       VariableConst{