[ir][spirv-writer] Fix OpPhi incoming blocks

A terminator instruction may be emitted inside the SPIR-V merge block
of an earlier control instruction in the same IR block, so we cannot
just use the label of the containing IR block for OpPhi incoming block
labels. Instead, walk backwards from the terminator to find the nearest
preceding control instruction and use its merge block label.

Bug: tint:1906
Change-Id: Ifcd30f8492d992742d72ec47c6cb4d5aac207b9d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/151500
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/writer/if_test.cc b/src/tint/lang/spirv/writer/if_test.cc
index 152e583..50cf1e5 100644
--- a/src/tint/lang/spirv/writer/if_test.cc
+++ b/src/tint/lang/spirv/writer/if_test.cc
@@ -288,5 +288,49 @@
 )");
 }
 
+TEST_F(SpirvWriterTest, If_Phi_Nested) {  // Outer phi must use the inner if's merge label.
+    auto* func = b.Function("foo", ty.i32());
+    b.Append(func->Block(), [&] {
+        auto* outer = b.If(true);
+        outer->SetResults(b.InstructionResult(ty.i32()));
+        b.Append(outer->True(), [&] {  //
+            auto* inner = b.If(true);
+            inner->SetResults(b.InstructionResult(ty.i32()));
+            b.Append(inner->True(), [&] {  //
+                b.ExitIf(inner, 10_i);
+            });
+            b.Append(inner->False(), [&] {  //
+                b.ExitIf(inner, 20_i);
+            });
+            b.ExitIf(outer, inner->Result());  // Emitted in the inner if's merge block (%10).
+        });
+        b.Append(outer->False(), [&] {  //
+            b.ExitIf(outer, 30_i);
+        });
+        b.Return(func, outer);
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+               OpSelectionMerge %5 None
+               OpBranchConditional %true %6 %7
+          %6 = OpLabel
+               OpSelectionMerge %10 None
+               OpBranchConditional %true %11 %12
+         %11 = OpLabel
+               OpBranch %10
+         %12 = OpLabel
+               OpBranch %10
+         %10 = OpLabel
+         %13 = OpPhi %int %int_10 %11 %int_20 %12
+               OpBranch %5
+          %7 = OpLabel
+               OpBranch %5
+          %5 = OpLabel
+         %16 = OpPhi %int %int_30 %7 %13 %10
+               OpReturnValue %16
+)");
+}
+
 }  // namespace
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/loop_test.cc b/src/tint/lang/spirv/writer/loop_test.cc
index 1a44fdd..61760b5 100644
--- a/src/tint/lang/spirv/writer/loop_test.cc
+++ b/src/tint/lang/spirv/writer/loop_test.cc
@@ -330,7 +330,7 @@
         auto* loop = b.Loop();
 
         b.Append(loop->Initializer(), [&] {  //
-            b.NextIteration(loop, 1_i, false);
+            b.NextIteration(loop, 1_i);
         });
 
         auto* loop_param = b.BlockParam(ty.i32());
@@ -427,5 +427,137 @@
 )");
 }
 
+TEST_F(SpirvWriterTest, Loop_Phi_NestedIf) {  // Continuing phi must use the if's merge label.
+    auto* func = b.Function("foo", ty.void_());
+
+    b.Append(func->Block(), [&] {
+        auto* loop = b.Loop();
+        b.Append(loop->Initializer(), [&] {  //
+            b.NextIteration(loop, 1_i);
+        });
+
+        auto* loop_param = b.BlockParam(ty.i32());
+        loop->Body()->SetParams({loop_param});
+        b.Append(loop->Body(), [&] {
+            auto* inner = b.If(true);
+            inner->SetResults(b.InstructionResult(ty.i32()));
+            b.Append(inner->True(), [&] {  //
+                b.ExitIf(inner, 10_i);
+            });
+            b.Append(inner->False(), [&] {  //
+                b.ExitIf(inner, 20_i);
+            });
+            b.Continue(loop, inner->Result());  // Emitted in the if's merge block (%14).
+        });
+
+        auto* cont_param = b.BlockParam(ty.i32());
+        loop->Continuing()->SetParams({cont_param});
+        b.Append(loop->Continuing(), [&] {
+            auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
+            b.BreakIf(loop, cmp, cont_param);
+        });
+
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+          %4 = OpLabel
+               OpBranch %5
+          %5 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %int %int_1 %5 %13 %7
+               OpLoopMerge %9 %7 None
+               OpBranch %6
+          %6 = OpLabel
+               OpSelectionMerge %14 None
+               OpBranchConditional %true %15 %16
+         %15 = OpLabel
+               OpBranch %14
+         %16 = OpLabel
+               OpBranch %14
+         %14 = OpLabel
+         %19 = OpPhi %int %int_10 %15 %int_20 %16
+               OpBranch %7
+          %7 = OpLabel
+         %13 = OpPhi %int %19 %14
+         %22 = OpSGreaterThan %bool %13 %int_5
+               OpBranchConditional %22 %9 %8
+          %9 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)");
+}
+
+TEST_F(SpirvWriterTest, Loop_Phi_NestedLoop) {  // Phi must use the inner loop's merge label.
+    auto* func = b.Function("foo", ty.void_());
+
+    b.Append(func->Block(), [&] {
+        auto* outer = b.Loop();
+        b.Append(outer->Initializer(), [&] {  //
+            b.NextIteration(outer, 1_i);
+        });
+
+        auto* outer_param = b.BlockParam(ty.i32());
+        outer->Body()->SetParams({outer_param});
+        b.Append(outer->Body(), [&] {
+            auto* inner = b.Loop();
+            b.Append(inner->Initializer(), [&] {  //
+                b.NextIteration(inner);
+            });
+            b.Append(inner->Body(), [&] {  //
+                b.Continue(inner);
+            });
+            b.Append(inner->Continuing(), [&] {  //
+                b.BreakIf(inner, true);
+            });
+
+            b.Continue(outer, outer_param);  // Emitted in the inner loop's merge block (%18).
+        });
+
+        auto* cont_param = b.BlockParam(ty.i32());
+        outer->Continuing()->SetParams({cont_param});
+        b.Append(outer->Continuing(), [&] {
+            auto* cmp = b.GreaterThan(ty.bool_(), cont_param, 5_i);
+            b.BreakIf(outer, cmp, cont_param);
+        });
+
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+          %4 = OpLabel
+               OpBranch %5
+          %5 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %int %int_1 %5 %13 %7
+               OpLoopMerge %9 %7 None
+               OpBranch %6
+          %6 = OpLabel
+               OpBranch %14
+         %14 = OpLabel
+               OpBranch %17
+         %17 = OpLabel
+               OpLoopMerge %18 %16 None
+               OpBranch %15
+         %15 = OpLabel
+               OpBranch %16
+         %16 = OpLabel
+               OpBranchConditional %true %18 %17
+         %18 = OpLabel
+               OpBranch %7
+          %7 = OpLabel
+         %13 = OpPhi %int %11 %18
+         %21 = OpSGreaterThan %bool %13 %int_5
+               OpBranchConditional %21 %9 %8
+          %9 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)");
+}
+
 }  // namespace
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index eba5239..b098a86 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -748,7 +748,7 @@
         for (auto* incoming : block->InboundSiblingBranches()) {
             auto* arg = incoming->Args()[param_idx];
             ops.push_back(Value(arg));
-            ops.push_back(Label(incoming->Block()));
+            ops.push_back(GetTerminatorBlockLabel(incoming));
         }
 
         current_function_.push_inst(spv::Op::OpPhi, std::move(ops));
@@ -848,7 +848,7 @@
     // 2. branches somewhere instead of exiting the loop (e.g. return or break), or
     // 3. the if returns a value
     // Otherwise we skip them and branch straight to the merge block.
-    uint32_t merge_label = module_.NextId();
+    uint32_t merge_label = GetMergeLabel(i);
     TINT_SCOPED_ASSIGNMENT(if_merge_label_, merge_label);
 
     uint32_t true_label = merge_label;
@@ -1697,7 +1697,7 @@
     auto header_label = module_.NextId();
     TINT_SCOPED_ASSIGNMENT(loop_header_label_, header_label);
 
-    auto merge_label = module_.NextId();
+    auto merge_label = GetMergeLabel(loop);
     TINT_SCOPED_ASSIGNMENT(loop_merge_label_, merge_label);
 
     if (init_label != 0) {
@@ -1762,7 +1762,7 @@
         }
     }
 
-    uint32_t merge_label = module_.NextId();
+    uint32_t merge_label = GetMergeLabel(swtch);
     TINT_SCOPED_ASSIGNMENT(switch_merge_label_, merge_label);
 
     // Emit the OpSelectionMerge and OpSwitch instructions.
@@ -1935,7 +1935,7 @@
         Vector<Branch, 8> branches;
         branches.Reserve(inst->Exits().Count());
         for (auto& exit : inst->Exits()) {
-            branches.Push(Branch{Label(exit->Block()), exit->Args()[index]});
+            branches.Push(Branch{GetTerminatorBlockLabel(exit), exit->Args()[index]});
         }
         branches.Sort();  // Sort the branches by label to ensure deterministic output
 
@@ -1952,6 +1952,26 @@
     }
 }
 
+// Merge labels are allocated the first time they are requested, so a terminator that needs the
+// merge label of a control instruction and the emission of that control instruction itself get
+// the same ID no matter which of them asks first.
+uint32_t Printer::GetMergeLabel(core::ir::ControlInstruction* ci) {
+    return merge_block_labels_.GetOrCreate(ci, [&] { return module_.NextId(); });
+}
+
+uint32_t Printer::GetTerminatorBlockLabel(core::ir::Terminator* t) {
+    // Instructions that follow a control instruction are emitted into its SPIR-V merge block, so
+    // the closest control instruction before `t` identifies the block that `t` ends up in.
+    for (auto* inst = t->prev; inst != nullptr; inst = inst->prev) {
+        if (auto* ci = inst->As<core::ir::ControlInstruction>()) {
+            return GetMergeLabel(ci);
+        }
+    }
+
+    // No control instruction precedes `t`, so it is emitted directly in its parent block.
+    return Label(t->Block());
+}
+
 uint32_t Printer::TexelFormat(const core::TexelFormat format) {
     switch (format) {
         case core::TexelFormat::kBgra8Unorm:
diff --git a/src/tint/lang/spirv/writer/printer/printer.h b/src/tint/lang/spirv/writer/printer/printer.h
index f991bfe..2f0d53a 100644
--- a/src/tint/lang/spirv/writer/printer/printer.h
+++ b/src/tint/lang/spirv/writer/printer/printer.h
@@ -272,6 +272,16 @@
     /// @param inst the flow control instruction
     void EmitExitPhis(core::ir::ControlInstruction* inst);
 
+    /// Get the ID of the label of the SPIR-V merge block for a control instruction.
+    /// @param ci the control instruction to get the merge label for
+    /// @returns the label ID, which is allocated the first time it is requested for `ci`
+    uint32_t GetMergeLabel(core::ir::ControlInstruction* ci);
+
+    /// Get the ID of the label of the SPIR-V block that a terminator instruction is emitted into.
+    /// @param t the terminator instruction to get the block label for
+    /// @returns the merge label of the last preceding control instruction, or the block's label
+    uint32_t GetTerminatorBlockLabel(core::ir::Terminator* t);
+
     core::ir::Module* ir_;
     core::ir::Builder b_;
     writer::Module module_;
@@ -323,6 +333,9 @@
     /// The map of blocks to the IDs of their label instructions.
     Hashmap<core::ir::Block*, uint32_t, 8> block_labels_;
 
+    /// Map from control instructions to lazily allocated label IDs of their SPIR-V merge blocks.
+    Hashmap<core::ir::ControlInstruction*, uint32_t, 8> merge_block_labels_;
+
     /// The map of extended instruction set names to their result IDs.
     Hashmap<std::string_view, uint32_t, 2> imports_;
 
diff --git a/src/tint/lang/spirv/writer/switch_test.cc b/src/tint/lang/spirv/writer/switch_test.cc
index bfcefca..ef94384 100644
--- a/src/tint/lang/spirv/writer/switch_test.cc
+++ b/src/tint/lang/spirv/writer/switch_test.cc
@@ -369,5 +369,104 @@
 )");
 }
 
+TEST_F(SpirvWriterTest, Switch_Phi_NestedIf) {  // Phi must use the if's merge label.
+    auto* func = b.Function("foo", ty.i32());
+    b.Append(func->Block(), [&] {
+        auto* s = b.Switch(42_i);
+        s->SetResults(b.InstructionResult(ty.i32()));
+        auto* case_a = b.Case(s, Vector{core::ir::Switch::CaseSelector{b.Constant(1_i)},
+                                        core::ir::Switch::CaseSelector{nullptr}});
+        b.Append(case_a, [&] {  //
+            auto* inner = b.If(true);
+            inner->SetResults(b.InstructionResult(ty.i32()));
+            b.Append(inner->True(), [&] {  //
+                b.ExitIf(inner, 10_i);
+            });
+            b.Append(inner->False(), [&] {  //
+                b.ExitIf(inner, 20_i);
+            });
+
+            b.ExitSwitch(s, inner->Result());  // Emitted in the if's merge block (%9).
+        });
+
+        auto* case_b = b.Case(s, Vector{core::ir::Switch::CaseSelector{b.Constant(2_i)}});
+        b.Append(case_b, [&] {  //
+            b.ExitSwitch(s, 20_i);
+        });
+
+        b.Return(func, s);
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+          %4 = OpLabel
+               OpSelectionMerge %8 None
+               OpSwitch %int_42 %5 1 %5 2 %7
+          %5 = OpLabel
+               OpSelectionMerge %9 None
+               OpBranchConditional %true %10 %11
+         %10 = OpLabel
+               OpBranch %9
+         %11 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+         %14 = OpPhi %int %int_10 %10 %int_20 %11
+               OpBranch %8
+          %7 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %17 = OpPhi %int %int_20 %7 %14 %9
+               OpReturnValue %17
+               OpFunctionEnd
+)");
+}
+
+TEST_F(SpirvWriterTest, Switch_Phi_NestedSwitch) {  // Phi uses the inner merge label.
+    auto* func = b.Function("foo", ty.i32());
+    b.Append(func->Block(), [&] {
+        auto* outer = b.Switch(42_i);
+        outer->SetResults(b.InstructionResult(ty.i32()));
+        auto* case_a = b.Case(outer, Vector{core::ir::Switch::CaseSelector{b.Constant(1_i)},
+                                            core::ir::Switch::CaseSelector{nullptr}});
+        b.Append(case_a, [&] {  //
+            auto* inner = b.Switch(42_i);
+            auto* case_inner = b.Case(inner, Vector{core::ir::Switch::CaseSelector{b.Constant(2_i)},
+                                                    core::ir::Switch::CaseSelector{nullptr}});
+            b.Append(case_inner, [&] {  //
+                b.ExitSwitch(inner);
+            });
+
+            b.ExitSwitch(outer, 10_i);  // Emitted in the inner switch's merge block (%10).
+        });
+
+        auto* case_b = b.Case(outer, Vector{core::ir::Switch::CaseSelector{b.Constant(2_i)}});
+        b.Append(case_b, [&] {  //
+            b.ExitSwitch(outer, 20_i);
+        });
+
+        b.Return(func, outer);
+    });
+
+    ASSERT_TRUE(Generate()) << Error() << output_;
+    EXPECT_INST(R"(
+          %4 = OpLabel
+               OpSelectionMerge %8 None
+               OpSwitch %int_42 %5 1 %5 2 %7
+          %5 = OpLabel
+               OpSelectionMerge %10 None
+               OpSwitch %int_42 %9 2 %9
+          %9 = OpLabel
+               OpBranch %10
+         %10 = OpLabel
+               OpBranch %8
+          %7 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %int %int_20 %7 %int_10 %10
+               OpReturnValue %11
+               OpFunctionEnd
+)");
+}
+
 }  // namespace
 }  // namespace tint::spirv::writer