[spirv-reader][ir] Fix incorrect switch/loop exits.

In the case of an `if` where one of the branches jumps to a switch/loop
merge, and the other branch goes to the if merge which also jumps to the
switch/loop merge, from an IR perspective those two SPIR-V blocks would
refer to the same IR block. In this case, we need to differentiate by
looking at the `BranchConditional` and determining if we need to handle
this special terminator decision.

Bug: 42250952
Change-Id: I19d289bc455d7714b62898afc88ea16bd013a6fa
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/251197
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 1f6a86c..41b7c6d 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -2178,6 +2178,7 @@
         auto* type = Type(inst.type_id());
 
         auto* phi_spirv_block = spirv_context_->get_instr_block(&inst);
+        auto phi_blk_id = phi_spirv_block->id();
 
         // The merge target (which is the OpPhi SPIR-V block) is a walk stop block
         auto* loop = StopWalkingAt(phi_spirv_block->id())->As<core::ir::Loop>();
@@ -2191,9 +2192,34 @@
             auto value_id = inst.GetSingleWordInOperand(i);
             auto blk_id = inst.GetSingleWordInOperand(i + 1);
 
-            auto value_blk_iter = spirv_id_to_block_.find(blk_id);
-            auto* value_ir_blk = value_blk_iter->second;
-            auto* term = value_ir_blk->Terminator();
+            core::ir::Terminator* term = nullptr;
+
+            // If the basic block ends in a branch conditional, then we really need to update the
+            // branch which goes to the current block, not the terminator of the block we emitted
+            // into. This difference is because the merge block will _also_ end up in the same IR
+            // block and we can't tell the difference if we just get the terminator.
+            const auto& bb = current_spirv_function_->FindBlock(blk_id);
+            auto* terminator = (*bb).terminator();
+            if (terminator->opcode() == spv::Op::OpBranchConditional) {
+                uint32_t true_id = terminator->GetSingleWordInOperand(1);
+                uint32_t false_id = terminator->GetSingleWordInOperand(2);
+
+                auto iter = branch_conditional_to_if_.find(terminator);
+                if (iter != branch_conditional_to_if_.end()) {
+                    auto* if_ = iter->second;
+                    if (true_id == phi_blk_id) {
+                        term = if_->True()->Terminator();
+                    } else if (false_id == phi_blk_id) {
+                        term = if_->False()->Terminator();
+                    }
+                }
+            }
+
+            if (term == nullptr) {
+                auto value_blk_iter = spirv_id_to_block_.find(blk_id);
+                auto* value_ir_blk = value_blk_iter->second;
+                term = value_ir_blk->Terminator();
+            }
 
             if (term->Is<core::ir::Unreachable>()) {
                 default_value = value_id;
@@ -2496,25 +2522,54 @@
         core::ir::Switch* ctrl = nullptr;
         std::optional<core::ir::Value*> value_for_default_block = std::nullopt;
 
+        auto phi_blk_id = spirv_context_->get_instr_block(&inst)->id();
+
         for (uint32_t i = 0; i < inst.NumInOperands(); i += 2) {
             auto value_id = inst.GetSingleWordInOperand(i);
             auto blk_id = inst.GetSingleWordInOperand(i + 1);
 
-            auto value_blk_iter = spirv_id_to_block_.find(blk_id);
-            auto* value_ir_blk = value_blk_iter->second;
+            core::ir::Terminator* term = nullptr;
 
-            // In the case of a switch, the block can refer to the header of the switch, in this
-            // case it means we don't have a default block and we jump over the switch itself, so we
-            // need to insert this value into the terminator of the default block of the switch.
-            //
-            // We store this away as we may not know the switch yet, we need to wait until we get
-            // the control instruction and do the work later.
-            if (blk_id == header_id) {
-                value_for_default_block = Value(value_id);
-                continue;
+            // If the basic block ends in a branch conditional, then we really need to update the
+            // branch which goes to the current block, not the terminator of the block we emitted
+            // into. This difference is because the merge block will _also_ end up in the same IR
+            // block and we can't tell the difference if we just get the terminator.
+            const auto& bb = current_spirv_function_->FindBlock(blk_id);
+            auto* terminator = (*bb).terminator();
+            if (terminator->opcode() == spv::Op::OpBranchConditional) {
+                uint32_t true_id = terminator->GetSingleWordInOperand(1);
+                uint32_t false_id = terminator->GetSingleWordInOperand(2);
+
+                auto iter = branch_conditional_to_if_.find(terminator);
+                TINT_ASSERT(iter != branch_conditional_to_if_.end());
+
+                auto* if_ = iter->second;
+                if (true_id == phi_blk_id) {
+                    term = if_->True()->Terminator();
+                } else if (false_id == phi_blk_id) {
+                    term = if_->False()->Terminator();
+                }
             }
 
-            auto* term = value_ir_blk->Terminator();
+            if (term == nullptr) {
+                auto value_blk_iter = spirv_id_to_block_.find(blk_id);
+                auto* value_ir_blk = value_blk_iter->second;
+
+                // In the case of a switch, the block can refer to the header of the switch, in this
+                // case it means we don't have a default block and we jump over the switch itself,
+                // so we need to insert this value into the terminator of the default block of the
+                // switch.
+                //
+                // We store this away as we may not know the switch yet, we need to wait until we
+                // get the control instruction and do the work later.
+                if (blk_id == header_id) {
+                    value_for_default_block = Value(value_id);
+                    continue;
+                }
+
+                term = value_ir_blk->Terminator();
+            }
+
             // Push a placeholder for the operand value at this point. We'll store away the
             // terminator/index pair along with the required value and then fill it in at the end of
             // the block emission. This is because a PHI can refer to a value which is defined after
@@ -3120,6 +3175,8 @@
         auto* if_ = b_.If(cond);
         EmitWithoutResult(if_);
 
+        branch_conditional_to_if_.insert({&inst, if_});
+
         std::optional<uint32_t> merge_id = std::nullopt;
 
         auto* merge_inst = bb.GetMergeInst();
@@ -3146,7 +3203,6 @@
             premerge_if_ = b_.If(b_.Constant(true));
             walk_stop_blocks_.insert({premerge_start_id.value(), premerge_if_});
         }
-
         if (auto* ctrl = StopWalkingAt(true_id)) {
             auto* new_inst = EmitBranchStopBlock(ctrl, if_, if_->True(), true_id);
             inst_to_spirv_block_[new_inst] = bb.id();
@@ -4066,6 +4122,9 @@
     // Map of `var` to the access mode it was originally created with. This may be different from
     // the current mode if we needed to set a default mode.
     std::unordered_map<core::ir::Var*, core::Access> var_to_original_access_mode_;
+
+    // Map of spir-v branch conditional instructions to the related IR if instruction.
+    std::unordered_map<const spvtools::opt::Instruction*, core::ir::If*> branch_conditional_to_if_;
 };
 
 }  // namespace
diff --git a/src/tint/lang/spirv/reader/parser/phi_test.cc b/src/tint/lang/spirv/reader/parser/phi_test.cc
index 1e07b99..865d62e 100644
--- a/src/tint/lang/spirv/reader/parser/phi_test.cc
+++ b/src/tint/lang/spirv/reader/parser/phi_test.cc
@@ -308,6 +308,307 @@
 )");
 }
 
+TEST_F(SpirvParserTest, Phi_Switch_FromIfBreak) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+       %bool = OpTypeBool
+      %int_0 = OpConstant %int 0
+      %int_1 = OpConstant %int 1
+      %int_2 = OpConstant %int 2
+      %int_3 = OpConstant %int 3
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %3
+          %4 = OpLabel
+               OpSelectionMerge %13 None
+               OpSwitch %int_0 %10 0 %11 1 %12
+         %10 = OpLabel
+               OpBranch %13
+         %11 = OpLabel
+               OpBranch %13
+         %12 = OpLabel
+               OpSelectionMerge %20 None
+               OpBranchConditional %true %20 %21
+         %21 = OpLabel
+               OpBranch %13
+         %20 = OpLabel
+               OpBranch %13
+         %13 = OpLabel
+         %14 = OpPhi %int %int_0 %10 %int_1 %11 %int_2 %20 %int_3 %21
+         %15 = OpIAdd %int %14 %14
+               OpReturn
+               OpFunctionEnd
+)",
+              R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B1: {
+    %2:i32 = switch 0i [c: (default, $B2), c: (0i, $B3), c: (1i, $B4)] {  # switch_1
+      $B2: {  # case
+        exit_switch 0i  # switch_1
+      }
+      $B3: {  # case
+        exit_switch 1i  # switch_1
+      }
+      $B4: {  # case
+        if true [t: $B5, f: $B6] {  # if_1
+          $B5: {  # true
+            exit_if  # if_1
+          }
+          $B6: {  # false
+            exit_switch 3i  # switch_1
+          }
+        }
+        exit_switch 2i  # switch_1
+      }
+    }
+    %3:i32 = spirv.add<i32> %2, %2
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Switch_FromIfBreak_InDefault) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpName %main "main"
+       %void = OpTypeVoid
+          %9 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+       %bool = OpTypeBool
+    %float_7 = OpConstant %float 7
+    %float_8 = OpConstant %float 8
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %9
+         %45 = OpLabel
+               OpSelectionMerge %52 None
+               OpSwitch %uint_0 %53
+         %53 = OpLabel
+               OpBranch %54
+         %54 = OpLabel
+               OpSelectionMerge %84 None
+               OpBranchConditional %true %52 %84
+         %84 = OpLabel
+               OpBranch %52
+         %52 = OpLabel
+         %85 = OpPhi %float %float_7 %54 %float_8 %84
+        %100 = OpFAdd %float %85 %85
+               OpReturn
+               OpFunctionEnd
+)",
+              R"(
+%main = @fragment func():void {
+  $B1: {
+    %2:f32 = switch 0u [c: (default, $B2)] {  # switch_1
+      $B2: {  # case
+        if true [t: $B3, f: $B4] {  # if_1
+          $B3: {  # true
+            exit_switch 7.0f  # switch_1
+          }
+          $B4: {  # false
+            exit_if  # if_1
+          }
+        }
+        exit_switch 8.0f  # switch_1
+      }
+    }
+    %3:f32 = add %2, %2
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Switch_FromIf_BothJumpToMerge_InDefault) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpName %main "main"
+       %void = OpTypeVoid
+          %9 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+       %bool = OpTypeBool
+    %float_7 = OpConstant %float 7
+    %float_8 = OpConstant %float 8
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %9
+         %45 = OpLabel
+               OpSelectionMerge %52 None
+               OpSwitch %uint_0 %53
+         %53 = OpLabel
+               OpBranch %54
+         %54 = OpLabel
+               OpSelectionMerge %84 None
+               OpBranchConditional %true %52 %52
+         %84 = OpLabel
+               OpBranch %52
+         %52 = OpLabel
+         %85 = OpPhi %float %float_7 %54 %float_8 %84
+        %100 = OpFAdd %float %85 %85
+               OpReturn
+               OpFunctionEnd
+)",
+              R"(
+%main = @fragment func():void {
+  $B1: {
+    %2:f32 = switch 0u [c: (default, $B2)] {  # switch_1
+      $B2: {  # case
+        %3:bool = or true, true
+        if %3 [t: $B3, f: $B4] {  # if_1
+          $B3: {  # true
+            exit_switch 7.0f  # switch_1
+          }
+          $B4: {  # false
+            unreachable
+          }
+        }
+        exit_switch 8.0f  # switch_1
+      }
+    }
+    %4:f32 = add %2, %2
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Switch_FromIfBreakBoth_InDefault) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpName %main "main"
+       %void = OpTypeVoid
+          %9 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+       %bool = OpTypeBool
+    %float_7 = OpConstant %float 7
+    %float_8 = OpConstant %float 8
+    %float_9 = OpConstant %float 9
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %9
+         %45 = OpLabel
+               OpSelectionMerge %52 None
+               OpSwitch %uint_0 %53
+         %53 = OpLabel
+               OpBranch %54
+         %54 = OpLabel
+               OpSelectionMerge %84 None
+               OpBranchConditional %true %55 %56
+         %55 = OpLabel
+               OpBranch %52
+         %56 = OpLabel
+               OpBranch %52
+         %84 = OpLabel
+               OpBranch %52
+         %52 = OpLabel
+         %85 = OpPhi %float %float_7 %55 %float_8 %56 %float_9 %84
+        %100 = OpFAdd %float %85 %85
+               OpReturn
+               OpFunctionEnd
+)",
+              R"(
+%main = @fragment func():void {
+  $B1: {
+    %2:f32 = switch 0u [c: (default, $B2)] {  # switch_1
+      $B2: {  # case
+        if true [t: $B3, f: $B4] {  # if_1
+          $B3: {  # true
+            exit_switch 7.0f  # switch_1
+          }
+          $B4: {  # false
+            exit_switch 8.0f  # switch_1
+          }
+        }
+        exit_switch 9.0f  # switch_1
+      }
+    }
+    %3:f32 = add %2, %2
+    ret
+  }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Loop_FromIfBreak) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpName %main "main"
+       %void = OpTypeVoid
+          %9 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+       %bool = OpTypeBool
+    %float_7 = OpConstant %float 7
+    %float_8 = OpConstant %float 8
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %9
+         %45 = OpLabel
+               OpBranch %50
+         %50 = OpLabel
+               OpLoopMerge %52 %55 None
+               OpBranch %60
+         %55 = OpLabel
+               OpBranch %50
+         %60 = OpLabel
+               OpSelectionMerge %84 None
+               OpBranchConditional %true %52 %84
+         %84 = OpLabel
+               OpBranch %52
+         %52 = OpLabel
+         %85 = OpPhi %float %float_7 %60 %float_8 %84
+        %100 = OpFAdd %float %85 %85
+               OpReturn
+               OpFunctionEnd
+
+)",
+              R"(
+%main = @fragment func():void {
+  $B1: {
+    %2:f32 = loop [b: $B2, c: $B3] {  # loop_1
+      $B2: {  # body
+        if true [t: $B4, f: $B5] {  # if_1
+          $B4: {  # true
+            exit_loop 7.0f  # loop_1
+          }
+          $B5: {  # false
+            exit_if  # if_1
+          }
+        }
+        exit_loop 8.0f  # loop_1
+      }
+      $B3: {  # continuing
+        next_iteration  # -> $B2
+      }
+    }
+    %3:f32 = add %2, %2
+    ret
+  }
+}
+)");
+}
+
 TEST_F(SpirvParserTest, Phi_Loop_ContinueIsHeader) {
     EXPECT_IR(R"(
                OpCapability Shader