[spirv-reader][ir] Fix incorrect switch/loop exits.
In the case of an `if` where one of the branches jumps to a switch/loop
merge, and the other branch goes to the if merge which also jumps to the
switch/loop merge, from an IR perspective those two SPIR-V blocks would
refer to the same IR block. In this case, we need to differentiate by
looking at the `BranchConditional` and determining if we need to handle
this special terminator decision.
Bug: 42250952
Change-Id: I19d289bc455d7714b62898afc88ea16bd013a6fa
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/251197
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc
index 1f6a86c..41b7c6d 100644
--- a/src/tint/lang/spirv/reader/parser/parser.cc
+++ b/src/tint/lang/spirv/reader/parser/parser.cc
@@ -2178,6 +2178,7 @@
auto* type = Type(inst.type_id());
auto* phi_spirv_block = spirv_context_->get_instr_block(&inst);
+ auto phi_blk_id = phi_spirv_block->id();
// The merge target (which is the OpPhi SPIR-V block) is a walk stop block
auto* loop = StopWalkingAt(phi_spirv_block->id())->As<core::ir::Loop>();
@@ -2191,9 +2192,34 @@
auto value_id = inst.GetSingleWordInOperand(i);
auto blk_id = inst.GetSingleWordInOperand(i + 1);
- auto value_blk_iter = spirv_id_to_block_.find(blk_id);
- auto* value_ir_blk = value_blk_iter->second;
- auto* term = value_ir_blk->Terminator();
+ core::ir::Terminator* term = nullptr;
+
+ // If the basic block ends in a branch conditional, then we really need to update the
+ // branch which goes to the current block, not the terminator of the block we emitted
+ // into. This difference is because the merge block will _also_ end up in the same IR
+ // block and we can't tell the difference if we just get the terminator.
+ const auto& bb = current_spirv_function_->FindBlock(blk_id);
+ auto* terminator = (*bb).terminator();
+ if (terminator->opcode() == spv::Op::OpBranchConditional) {
+ uint32_t true_id = terminator->GetSingleWordInOperand(1);
+ uint32_t false_id = terminator->GetSingleWordInOperand(2);
+
+ auto iter = branch_conditional_to_if_.find(terminator);
+ if (iter != branch_conditional_to_if_.end()) {
+ auto* if_ = iter->second;
+ if (true_id == phi_blk_id) {
+ term = if_->True()->Terminator();
+ } else if (false_id == phi_blk_id) {
+ term = if_->False()->Terminator();
+ }
+ }
+ }
+
+ if (term == nullptr) {
+ auto value_blk_iter = spirv_id_to_block_.find(blk_id);
+ auto* value_ir_blk = value_blk_iter->second;
+ term = value_ir_blk->Terminator();
+ }
if (term->Is<core::ir::Unreachable>()) {
default_value = value_id;
@@ -2496,25 +2522,54 @@
core::ir::Switch* ctrl = nullptr;
std::optional<core::ir::Value*> value_for_default_block = std::nullopt;
+ auto phi_blk_id = spirv_context_->get_instr_block(&inst)->id();
+
for (uint32_t i = 0; i < inst.NumInOperands(); i += 2) {
auto value_id = inst.GetSingleWordInOperand(i);
auto blk_id = inst.GetSingleWordInOperand(i + 1);
- auto value_blk_iter = spirv_id_to_block_.find(blk_id);
- auto* value_ir_blk = value_blk_iter->second;
+ core::ir::Terminator* term = nullptr;
- // In the case of a switch, the block can refer to the header of the switch, in this
- // case it means we don't have a default block and we jump over the switch itself, so we
- // need to insert this value into the terminator of the default block of the switch.
- //
- // We store this away as we may not know the switch yet, we need to wait until we get
- // the control instruction and do the work later.
- if (blk_id == header_id) {
- value_for_default_block = Value(value_id);
- continue;
+ // If the basic block ends in a branch conditional, then we really need to update the
+ // branch which goes to the current block, not the terminator of the block we emitted
+ // into. This difference is because the merge block will _also_ end up in the same IR
+ // block and we can't tell the difference if we just get the terminator.
+ const auto& bb = current_spirv_function_->FindBlock(blk_id);
+ auto* terminator = (*bb).terminator();
+ if (terminator->opcode() == spv::Op::OpBranchConditional) {
+ uint32_t true_id = terminator->GetSingleWordInOperand(1);
+ uint32_t false_id = terminator->GetSingleWordInOperand(2);
+
+ auto iter = branch_conditional_to_if_.find(terminator);
+ TINT_ASSERT(iter != branch_conditional_to_if_.end());
+
+ auto* if_ = iter->second;
+ if (true_id == phi_blk_id) {
+ term = if_->True()->Terminator();
+ } else if (false_id == phi_blk_id) {
+ term = if_->False()->Terminator();
+ }
}
- auto* term = value_ir_blk->Terminator();
+ if (term == nullptr) {
+ auto value_blk_iter = spirv_id_to_block_.find(blk_id);
+ auto* value_ir_blk = value_blk_iter->second;
+
+ // In the case of a switch, the block can refer to the header of the switch, in this
+ // case it means we don't have a default block and we jump over the switch itself,
+ // so we need to insert this value into the terminator of the default block of the
+ // switch.
+ //
+ // We store this away as we may not know the switch yet, we need to wait until we
+ // get the control instruction and do the work later.
+ if (blk_id == header_id) {
+ value_for_default_block = Value(value_id);
+ continue;
+ }
+
+ term = value_ir_blk->Terminator();
+ }
+
// Push a placeholder for the operand value at this point. We'll store away the
// terminator/index pair along with the required value and then fill it in at the end of
// the block emission. This is because a PHI can refer to a value which is defined after
@@ -3120,6 +3175,8 @@
auto* if_ = b_.If(cond);
EmitWithoutResult(if_);
+ branch_conditional_to_if_.insert({&inst, if_});
+
std::optional<uint32_t> merge_id = std::nullopt;
auto* merge_inst = bb.GetMergeInst();
@@ -3146,7 +3203,6 @@
premerge_if_ = b_.If(b_.Constant(true));
walk_stop_blocks_.insert({premerge_start_id.value(), premerge_if_});
}
-
if (auto* ctrl = StopWalkingAt(true_id)) {
auto* new_inst = EmitBranchStopBlock(ctrl, if_, if_->True(), true_id);
inst_to_spirv_block_[new_inst] = bb.id();
@@ -4066,6 +4122,9 @@
// Map of `var` to the access mode it was originally created with. This may be different from
// the current mode if we needed to set a default mode.
std::unordered_map<core::ir::Var*, core::Access> var_to_original_access_mode_;
+
+ // Map of spir-v branch conditional instructions to the related IR if instruction.
+ std::unordered_map<const spvtools::opt::Instruction*, core::ir::If*> branch_conditional_to_if_;
};
} // namespace
diff --git a/src/tint/lang/spirv/reader/parser/phi_test.cc b/src/tint/lang/spirv/reader/parser/phi_test.cc
index 1e07b99..865d62e 100644
--- a/src/tint/lang/spirv/reader/parser/phi_test.cc
+++ b/src/tint/lang/spirv/reader/parser/phi_test.cc
@@ -308,6 +308,307 @@
)");
}
+TEST_F(SpirvParserTest, Phi_Switch_FromIfBreak) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %int = OpTypeInt 32 1
+ %bool = OpTypeBool
+ %int_0 = OpConstant %int 0
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %int_3 = OpConstant %int 3
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %3
+ %4 = OpLabel
+ OpSelectionMerge %13 None
+ OpSwitch %int_0 %10 0 %11 1 %12
+ %10 = OpLabel
+ OpBranch %13
+ %11 = OpLabel
+ OpBranch %13
+ %12 = OpLabel
+ OpSelectionMerge %20 None
+ OpBranchConditional %true %20 %21
+ %21 = OpLabel
+ OpBranch %13
+ %20 = OpLabel
+ OpBranch %13
+ %13 = OpLabel
+ %14 = OpPhi %int %int_0 %10 %int_1 %11 %int_2 %20 %int_3 %21
+ %15 = OpIAdd %int %14 %14
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B1: {
+ %2:i32 = switch 0i [c: (default, $B2), c: (0i, $B3), c: (1i, $B4)] { # switch_1
+ $B2: { # case
+ exit_switch 0i # switch_1
+ }
+ $B3: { # case
+ exit_switch 1i # switch_1
+ }
+ $B4: { # case
+ if true [t: $B5, f: $B6] { # if_1
+ $B5: { # true
+ exit_if # if_1
+ }
+ $B6: { # false
+ exit_switch 3i # switch_1
+ }
+ }
+ exit_switch 2i # switch_1
+ }
+ }
+ %3:i32 = spirv.add<i32> %2, %2
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Switch_FromIfBreak_InDefault) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpName %main "main"
+ %void = OpTypeVoid
+ %9 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %bool = OpTypeBool
+ %float_7 = OpConstant %float 7
+ %float_8 = OpConstant %float 8
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %9
+ %45 = OpLabel
+ OpSelectionMerge %52 None
+ OpSwitch %uint_0 %53
+ %53 = OpLabel
+ OpBranch %54
+ %54 = OpLabel
+ OpSelectionMerge %84 None
+ OpBranchConditional %true %52 %84
+ %84 = OpLabel
+ OpBranch %52
+ %52 = OpLabel
+ %85 = OpPhi %float %float_7 %54 %float_8 %84
+ %100 = OpFAdd %float %85 %85
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @fragment func():void {
+ $B1: {
+ %2:f32 = switch 0u [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ if true [t: $B3, f: $B4] { # if_1
+ $B3: { # true
+ exit_switch 7.0f # switch_1
+ }
+ $B4: { # false
+ exit_if # if_1
+ }
+ }
+ exit_switch 8.0f # switch_1
+ }
+ }
+ %3:f32 = add %2, %2
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Switch_FromIf_BothJumpToMerge_InDefault) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpName %main "main"
+ %void = OpTypeVoid
+ %9 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %bool = OpTypeBool
+ %float_7 = OpConstant %float 7
+ %float_8 = OpConstant %float 8
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %9
+ %45 = OpLabel
+ OpSelectionMerge %52 None
+ OpSwitch %uint_0 %53
+ %53 = OpLabel
+ OpBranch %54
+ %54 = OpLabel
+ OpSelectionMerge %84 None
+ OpBranchConditional %true %52 %52
+ %84 = OpLabel
+ OpBranch %52
+ %52 = OpLabel
+ %85 = OpPhi %float %float_7 %54 %float_8 %84
+ %100 = OpFAdd %float %85 %85
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @fragment func():void {
+ $B1: {
+ %2:f32 = switch 0u [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %3:bool = or true, true
+ if %3 [t: $B3, f: $B4] { # if_1
+ $B3: { # true
+ exit_switch 7.0f # switch_1
+ }
+ $B4: { # false
+ unreachable
+ }
+ }
+ exit_switch 8.0f # switch_1
+ }
+ }
+ %4:f32 = add %2, %2
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Switch_FromIfBreakBoth_InDefault) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpName %main "main"
+ %void = OpTypeVoid
+ %9 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %bool = OpTypeBool
+ %float_7 = OpConstant %float 7
+ %float_8 = OpConstant %float 8
+ %float_9 = OpConstant %float 9
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %9
+ %45 = OpLabel
+ OpSelectionMerge %52 None
+ OpSwitch %uint_0 %53
+ %53 = OpLabel
+ OpBranch %54
+ %54 = OpLabel
+ OpSelectionMerge %84 None
+ OpBranchConditional %true %55 %56
+ %55 = OpLabel
+ OpBranch %52
+ %56 = OpLabel
+ OpBranch %52
+ %84 = OpLabel
+ OpBranch %52
+ %52 = OpLabel
+ %85 = OpPhi %float %float_7 %55 %float_8 %56 %float_9 %84
+ %100 = OpFAdd %float %85 %85
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+%main = @fragment func():void {
+ $B1: {
+ %2:f32 = switch 0u [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ if true [t: $B3, f: $B4] { # if_1
+ $B3: { # true
+ exit_switch 7.0f # switch_1
+ }
+ $B4: { # false
+ exit_switch 8.0f # switch_1
+ }
+ }
+ exit_switch 9.0f # switch_1
+ }
+ }
+ %3:f32 = add %2, %2
+ ret
+ }
+}
+)");
+}
+
+TEST_F(SpirvParserTest, Phi_Loop_FromIfBreak) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpName %main "main"
+ %void = OpTypeVoid
+ %9 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %bool = OpTypeBool
+ %float_7 = OpConstant %float 7
+ %float_8 = OpConstant %float 8
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %9
+ %45 = OpLabel
+ OpBranch %50
+ %50 = OpLabel
+ OpLoopMerge %52 %55 None
+ OpBranch %60
+ %55 = OpLabel
+ OpBranch %50
+ %60 = OpLabel
+ OpSelectionMerge %84 None
+ OpBranchConditional %true %52 %84
+ %84 = OpLabel
+ OpBranch %52
+ %52 = OpLabel
+ %85 = OpPhi %float %float_7 %60 %float_8 %84
+ %100 = OpFAdd %float %85 %85
+ OpReturn
+ OpFunctionEnd
+
+)",
+ R"(
+%main = @fragment func():void {
+ $B1: {
+ %2:f32 = loop [b: $B2, c: $B3] { # loop_1
+ $B2: { # body
+ if true [t: $B4, f: $B5] { # if_1
+ $B4: { # true
+ exit_loop 7.0f # loop_1
+ }
+ $B5: { # false
+ exit_if # if_1
+ }
+ }
+ exit_loop 8.0f # loop_1
+ }
+ $B3: { # continuing
+ next_iteration # -> $B2
+ }
+ }
+ %3:f32 = add %2, %2
+ ret
+ }
+}
+)");
+}
+
TEST_F(SpirvParserTest, Phi_Loop_ContinueIsHeader) {
EXPECT_IR(R"(
OpCapability Shader