spriv-reader: handle loop-header with internal divergence

Handle the case where the OpBranchConditional in a loop header
branches to two distinct blocks inside the loop construct.
This is an if-selection in disguise.

Create an kIfSelection with the same set of blocks as the kLoop,
and with the continue target as the merge.

Fixed: tint:524
Change-Id: I5150d19a2b4388da409e2da6e68ffafdc5d21a9a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47560
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Alan Baker <alanbaker@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 9c5b8b4..4b25b05 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -1325,6 +1325,22 @@
           // in the block order, starting at the header, until just
           // before the continue target.
           top = push_construct(depth, Construct::kLoop, header, ct);
+
+          // If the loop header branches to two different blocks inside the loop
+          // construct, then the loop body should be modeled as an if-selection
+          // construct
+          std::vector<uint32_t> targets;
+          header_info->basic_block->ForEachSuccessorLabel(
+              [&targets](const uint32_t target) { targets.push_back(target); });
+          if ((targets.size() == 2u) && targets[0] != targets[1]) {
+            const auto target0_pos = GetBlockInfo(targets[0])->pos;
+            const auto target1_pos = GetBlockInfo(targets[1])->pos;
+            if (top->ContainsPos(target0_pos) &&
+                top->ContainsPos(target1_pos)) {
+              // Insert a synthetic if-selection
+              top = push_construct(depth+1, Construct::kIfSelection, header, ct);
+            }
+          }
         }
       } else {
         // From the interval rule, the selection construct consists of blocks
@@ -1705,7 +1721,7 @@
         if ((edge_kind == EdgeKind::kForward) ||
             (edge_kind == EdgeKind::kCaseFallThrough)) {
           // Check for an invalid forward exit out of this construct.
-          if (dest_info->pos >= src_construct.end_pos) {
+          if (dest_info->pos > src_construct.end_pos) {
             // In most cases we're bypassing the merge block for the source
             // construct.
             auto end_block = src_construct.end_id;
@@ -2151,10 +2167,12 @@
   // What constructs can we have entered?
   // - It can't be kFunction, because there is only one of those, and it was
   //   already on the stack at the outermost level.
-  // - We have at most one of kIfSelection, kSwitchSelection, or kLoop because
-  //   each of those is headed by a block with a merge instruction (OpLoopMerge
-  //   for kLoop, and OpSelectionMerge for the others), and the kIfSelection and
-  //   kSwitchSelection header blocks end in different branch instructions.
+  // - We have at most one of kSwitchSelection, or kLoop because each of those
+  //   is headed by a block with a merge instruction (OpLoopMerge for kLoop,
+  //   and OpSelectionMerge for kSwitchSelection).
+  // - When there is a kIfSelection, it can't contain another construct,
+  //   because both would have to have their own distinct merge instructions
+  //   and distinct terminators.
   // - A kContinue can contain a kContinue
   //   This is possible in Vulkan SPIR-V, but Tint disallows this by the rule
   //   that a block can be continue target for at most one header block. See
@@ -2162,8 +2180,14 @@
   //   then by a dominance argument, the inner loop continue target can only be
   //   a single-block loop.
   // TODO(dneto): Handle this case.
-  // - All that's left is a kContinue and one of kIfSelection, kSwitchSelection,
-  //   kLoop.
+  // - If a kLoop is on the outside, its terminator is either:
+  //   - an OpBranch, in which case there is no other construct.
+  //   - an OpBranchConditional, in which case there is either an kIfSelection
+  //     (when both branch targets are different and are inside the loop),
+  //     or no other construct (because the branch targets are the same,
+  //     or one of them is a break or continue).
+  // - All that's left is a kContinue on the outside, and one of
+  //   kIfSelection, kSwitchSelection, kLoop on the inside.
   //
   //   The kContinue can be the parent of the other.  For example, a selection
   //   starting at the first block of a continue construct.
@@ -2189,19 +2213,20 @@
   //
   // So we fall into one of the following cases:
   //  - We are entering 0 or 1 constructs, or
-  //  - We are entering 2 constructs, with the outer one being a kContinue, the
-  //    inner one is not a continue.
+  //  - We are entering 2 constructs, with the outer one being a kContinue or
+  //    kLoop, the inner one is not a continue.
   if (entering_constructs.size() > 2) {
     return Fail() << "internal error: bad construct nesting found";
   }
   if (entering_constructs.size() == 2) {
     auto inner_kind = entering_constructs[0]->kind;
     auto outer_kind = entering_constructs[1]->kind;
-    if (outer_kind != Construct::kContinue) {
-      return Fail() << "internal error: bad construct nesting. Only Continue "
-                       "construct can be outer construct on same block.  Got "
-                       "outer kind "
-                    << int(outer_kind) << " inner kind " << int(inner_kind);
+    if (outer_kind != Construct::kContinue && outer_kind != Construct::kLoop) {
+      return Fail()
+             << "internal error: bad construct nesting. Only a Continue "
+                "or a Loop construct can be outer construct on same block.  "
+                "Got outer kind "
+             << int(outer_kind) << " inner kind " << int(inner_kind);
     }
     if (inner_kind == Construct::kContinue) {
       return Fail() << "internal error: unsupported construct nesting: "
diff --git a/src/reader/spirv/function_cfg_test.cc b/src/reader/spirv/function_cfg_test.cc
index 77b5254..af99038 100644
--- a/src/reader/spirv/function_cfg_test.cc
+++ b/src/reader/spirv/function_cfg_test.cc
@@ -70,10 +70,13 @@
     %uint_6 = OpConstant %uint 6
     %uint_7 = OpConstant %uint 7
     %uint_8 = OpConstant %uint 8
+    %uint_10 = OpConstant %uint 10
     %uint_20 = OpConstant %uint 20
     %uint_30 = OpConstant %uint 30
     %uint_40 = OpConstant %uint 40
     %uint_50 = OpConstant %uint 50
+    %uint_90 = OpConstant %uint 90
+    %uint_99 = OpConstant %uint 99
 
     %ptr_Private_uint = OpTypePointer Private %uint
     %var = OpVariable %ptr_Private_uint Private
@@ -3672,6 +3675,55 @@
   EXPECT_EQ(fe.GetBlockInfo(99)->construct, constructs[0].get());
 }
 
+TEST_F(SpvParserTest, LabelControlFlowConstructs_LoopInterallyDiverge) {
+  // In this case, insert a synthetic if-selection with the same blocks
+  // as the loop construct.
+  // crbug.com/tint/524
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+
+     %10 = OpLabel
+     OpBranch %20
+
+     %20 = OpLabel
+     OpLoopMerge %99 %90 None
+     OpBranchConditional %cond %30 %40 ; divergence to distinct targets in the body
+
+       %30 = OpLabel
+       OpBranch %90
+
+       %40 = OpLabel
+       OpBranch %90
+
+     %90 = OpLabel ; continue target
+     OpBranch %20
+
+     %99 = OpLabel ; loop merge
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  ASSERT_TRUE(FlowLabelControlFlowConstructs(&fe)) << p->error();
+  const auto& constructs = fe.constructs();
+  EXPECT_EQ(constructs.size(), 4u);
+  ASSERT_THAT(ToString(constructs), Eq(R"(ConstructList{
+  Construct{ Function [0,6) begin_id:10 end_id:0 depth:0 parent:null }
+  Construct{ Continue [4,5) begin_id:90 end_id:99 depth:1 parent:Function@10 in-c:Continue@90 }
+  Construct{ Loop [1,4) begin_id:20 end_id:90 depth:1 parent:Function@10 scope:[1,5) in-l:Loop@20 }
+  Construct{ IfSelection [1,4) begin_id:20 end_id:90 depth:2 parent:Loop@20 in-l:Loop@20 }
+})")) << constructs;
+  // The block records the nearest enclosing construct.
+  EXPECT_EQ(fe.GetBlockInfo(10)->construct, constructs[0].get());
+  EXPECT_EQ(fe.GetBlockInfo(20)->construct, constructs[3].get());
+  EXPECT_EQ(fe.GetBlockInfo(30)->construct, constructs[3].get());
+  EXPECT_EQ(fe.GetBlockInfo(40)->construct, constructs[3].get());
+  EXPECT_EQ(fe.GetBlockInfo(90)->construct, constructs[1].get());
+  EXPECT_EQ(fe.GetBlockInfo(99)->construct, constructs[0].get());
+}
+
 TEST_F(SpvParserTest, FindSwitchCaseHeaders_DefaultIsLongRangeBackedge) {
   auto assembly = CommonTypes() + R"(
      %100 = OpFunction %void None %voidfn
@@ -14259,6 +14311,87 @@
   ASSERT_EQ(expect, got);
 }
 
+TEST_F(SpvParserTest, EmitBody_LoopInternallyDiverge_Simple) {
+  // crbug.com/tint/524
+  auto assembly = CommonTypes() + R"(
+     %100 = OpFunction %void None %voidfn
+     %10 = OpLabel
+     OpStore %var %uint_10
+     OpBranch %20
+
+     %20 = OpLabel
+     OpStore %var %uint_20
+     OpLoopMerge %99 %90 None
+     OpBranchConditional %cond %30 %40 ; divergence
+
+       %30 = OpLabel
+       OpStore %var %uint_30
+       OpBranch %90
+
+       %40 = OpLabel
+       OpStore %var %uint_40
+       OpBranch %90
+
+     %90 = OpLabel ; continue target
+     OpStore %var %uint_90
+     OpBranch %20
+
+     %99 = OpLabel ; loop merge
+     OpStore %var %uint_99
+     OpReturn
+
+     OpFunctionEnd
+)";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+  FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  auto got = ToString(p->builder(), fe.ast_body());
+  auto* expect = R"(Assignment{
+  Identifier[not set]{var_1}
+  ScalarConstructor[not set]{10u}
+}
+Loop{
+  Assignment{
+    Identifier[not set]{var_1}
+    ScalarConstructor[not set]{20u}
+  }
+  If{
+    (
+      ScalarConstructor[not set]{false}
+    )
+    {
+      Assignment{
+        Identifier[not set]{var_1}
+        ScalarConstructor[not set]{30u}
+      }
+      Continue{}
+    }
+  }
+  Else{
+    {
+      Assignment{
+        Identifier[not set]{var_1}
+        ScalarConstructor[not set]{40u}
+      }
+    }
+  }
+  continuing {
+    Assignment{
+      Identifier[not set]{var_1}
+      ScalarConstructor[not set]{90u}
+    }
+  }
+}
+Assignment{
+  Identifier[not set]{var_1}
+  ScalarConstructor[not set]{99u}
+}
+Return{}
+)";
+  ASSERT_EQ(expect, got) << got;
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader