[ir][spirv-writer] Implement switch instructions
Bug: tint:1906
Change-Id: Ie7de641227e4a2de2b812a8c79789c34af9f4e11
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134741
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 3305b8d..bc74d5d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1972,6 +1972,7 @@
"writer/spirv/ir/generator_impl_ir_function_test.cc",
"writer/spirv/ir/generator_impl_ir_if_test.cc",
"writer/spirv/ir/generator_impl_ir_loop_test.cc",
+ "writer/spirv/ir/generator_impl_ir_switch_test.cc",
"writer/spirv/ir/generator_impl_ir_test.cc",
"writer/spirv/ir/generator_impl_ir_type_test.cc",
"writer/spirv/ir/generator_impl_ir_var_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index d8fe22a..277d819 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1261,6 +1261,7 @@
writer/spirv/ir/generator_impl_ir_function_test.cc
writer/spirv/ir/generator_impl_ir_if_test.cc
writer/spirv/ir/generator_impl_ir_loop_test.cc
+ writer/spirv/ir/generator_impl_ir_switch_test.cc
writer/spirv/ir/generator_impl_ir_test.cc
writer/spirv/ir/generator_impl_ir_type_test.cc
writer/spirv/ir/generator_impl_ir_var_test.cc
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 89a1a1c..a32f5b7 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -25,6 +25,7 @@
#include "src/tint/ir/continue.h"
#include "src/tint/ir/exit_if.h"
#include "src/tint/ir/exit_loop.h"
+#include "src/tint/ir/exit_switch.h"
#include "src/tint/ir/if.h"
#include "src/tint/ir/load.h"
#include "src/tint/ir/loop.h"
@@ -32,6 +33,7 @@
#include "src/tint/ir/next_iteration.h"
#include "src/tint/ir/return.h"
#include "src/tint/ir/store.h"
+#include "src/tint/ir/switch.h"
#include "src/tint/ir/transform/add_empty_entry_point.h"
#include "src/tint/ir/user_call.h"
#include "src/tint/ir/var.h"
@@ -343,6 +345,10 @@
EmitLoop(l);
return 0u;
},
+ [&](const ir::Switch* sw) {
+ EmitSwitch(sw);
+ return 0u;
+ },
[&](const ir::Store* s) {
EmitStore(s);
return 0u;
@@ -397,6 +403,9 @@
[&](const ir::ExitLoop* loop) {
current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Merge())});
},
+ [&](const ir::ExitSwitch* swtch) {
+ current_function_.push_inst(spv::Op::OpBranch, {Label(swtch->Switch()->Merge())});
+ },
[&](const ir::NextIteration* loop) {
current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Start())});
},
@@ -651,6 +660,45 @@
EmitBlock(loop->Merge());
}
+void GeneratorImplIr::EmitSwitch(const ir::Switch* swtch) {  // -> OpSelectionMerge + OpSwitch; Label()/Value() call order appears to fix result-ID numbering
+ // Find the default case label. The IR is expected to have exactly one default selector.
+ uint32_t default_label = 0u;  // 0 means "not found"; SPIR-V result IDs are never 0
+ for (auto& c : swtch->Cases()) {
+ for (auto& sel : c.selectors) {
+ if (sel.IsDefault()) {
+ default_label = Label(c.Start());  // no break: if several defaults existed, the last would win
+ }
+ }
+ }
+ TINT_ASSERT(Writer, default_label != 0u);  // checks presence only, not uniqueness
+
+ // Build the OpSwitch operands: selector, default label, then one (literal, label) pair per value.
+ OperandList switch_operands = {Value(swtch->Condition()), default_label};
+ for (auto& c : swtch->Cases()) {
+ auto label = Label(c.Start());  // every selector of a case targets the same label
+ for (auto& sel : c.selectors) {
+ if (sel.IsDefault()) {
+ continue;  // already encoded as the default target above
+ }
+ switch_operands.push_back(sel.val->Value()->ValueAs<uint32_t>());  // NOTE(review): confirm negative i32 selectors keep their bit pattern
+ switch_operands.push_back(label);
+ }
+ }
+
+ // Emit OpSelectionMerge (declaring the switch merge block) followed by OpSwitch.
+ current_function_.push_inst(spv::Op::OpSelectionMerge,
+ {Label(swtch->Merge()), U32Operand(SpvSelectionControlMaskNone)});
+ current_function_.push_inst(spv::Op::OpSwitch, switch_operands);
+
+ // Emit the case blocks in the order the IR lists them.
+ for (auto& c : swtch->Cases()) {
+ EmitBlock(c.Start());
+ }
+
+ // Emit the switch merge block; ExitSwitch instructions branch to its label.
+ EmitBlock(swtch->Merge());
+}
+
void GeneratorImplIr::EmitStore(const ir::Store* store) {
current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
}
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 5293fea..3c8783e 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -39,6 +39,7 @@
class Loop;
class Module;
class Store;
+class Switch;
class UserCall;
class Value;
class Var;
@@ -130,6 +131,10 @@
/// @param store the store instruction to emit
void EmitStore(const ir::Store* store);
+ /// Emit a switch instruction.
+ /// @param swtch the switch instruction to emit
+ void EmitSwitch(const ir::Switch* swtch);
+
/// Emit a user call instruction.
/// @param call the user call instruction to emit
/// @returns the result ID of the instruction
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
new file mode 100644
index 0000000..5ef5901
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -0,0 +1,221 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+TEST_F(SpvGeneratorImplTest, Switch_Basic) {  // switch whose only case is the default
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* swtch = b.CreateSwitch(b.Constant(42_i));
+
+ auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});  // valueless selector = default
+ def_case->Instructions().Push(b.ExitSwitch(swtch));  // lowers to OpBranch to the switch merge label
+
+ swtch->Merge()->Instructions().Push(b.Return(func));
+
+ func->StartTarget()->Instructions().Push(swtch);
+
+ generator_.EmitFunction(func);  // default label %5 is numbered before the condition constant %6
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %8 None
+OpSwitch %6 %5
+%5 = OpLabel
+OpBranch %8
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_MultipleCases) {  // two value cases plus a default
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* swtch = b.CreateSwitch(b.Constant(42_i));
+
+ auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ case_a->Instructions().Push(b.ExitSwitch(swtch));
+
+ auto* case_b = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->Instructions().Push(b.ExitSwitch(swtch));
+
+ auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ def_case->Instructions().Push(b.ExitSwitch(swtch));
+
+ swtch->Merge()->Instructions().Push(b.Return(func));
+
+ func->StartTarget()->Instructions().Push(swtch);
+
+ generator_.EmitFunction(func);  // case labels %8/%9 follow the condition; the merge label is %10
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %10 None
+OpSwitch %6 %5 1 %8 2 %9
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranch %10
+%5 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_MultipleSelectorsPerCase) {  // selectors of one case share its label
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* swtch = b.CreateSwitch(b.Constant(42_i));
+
+ auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+ ir::Switch::CaseSelector{b.Constant(3_i)}});
+ case_a->Instructions().Push(b.ExitSwitch(swtch));
+
+ auto* case_b = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)},
+ ir::Switch::CaseSelector{b.Constant(4_i)}});
+ case_b->Instructions().Push(b.ExitSwitch(swtch));
+
+ auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(5_i)},
+ ir::Switch::CaseSelector()});  // a case may mix a value selector with the default
+ def_case->Instructions().Push(b.ExitSwitch(swtch));
+
+ swtch->Merge()->Instructions().Push(b.Return(func));
+
+ func->StartTarget()->Instructions().Push(swtch);
+
+ generator_.EmitFunction(func);  // expect one (literal, label) pair per non-default selector
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %10 None
+OpSwitch %6 %5 1 %8 3 %8 2 %9 4 %9 5 %5
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranch %10
+%5 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_AllCasesReturn) {  // no case reaches the merge block
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* swtch = b.CreateSwitch(b.Constant(42_i));
+
+ auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ case_a->Instructions().Push(b.Return(func));
+
+ auto* case_b = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+ case_b->Instructions().Push(b.Return(func));
+
+ auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ def_case->Instructions().Push(b.Return(func));
+
+ func->StartTarget()->Instructions().Push(swtch);
+
+ generator_.EmitFunction(func);  // the empty merge block %10 is emitted as OpUnreachable
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %10 None
+OpSwitch %6 %5 1 %8 2 %9
+%8 = OpLabel
+OpReturn
+%9 = OpLabel
+OpReturn
+%5 = OpLabel
+OpReturn
+%10 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_ConditionalBreak) {  // ExitSwitch nested inside an if
+ auto* func = b.CreateFunction("foo", mod.Types().void_());
+
+ auto* swtch = b.CreateSwitch(b.Constant(42_i));
+
+ auto* cond_break = b.CreateIf(b.Constant(true));
+ cond_break->True()->Instructions().Push(b.ExitSwitch(swtch));  // branches straight to the switch merge %9
+ cond_break->False()->Instructions().Push(b.ExitIf(cond_break));  // false edge targets the if merge %10 directly
+ cond_break->Merge()->Instructions().Push(b.Return(func));
+
+ auto* case_a = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+ case_a->Instructions().Push(cond_break);
+
+ auto* def_case = b.CreateCase(swtch, utils::Vector{ir::Switch::CaseSelector()});
+ def_case->Instructions().Push(b.ExitSwitch(swtch));
+
+ swtch->Merge()->Instructions().Push(b.Return(func));
+
+ func->StartTarget()->Instructions().Push(swtch);
+
+ generator_.EmitFunction(func);
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%13 = OpTypeBool
+%12 = OpConstantTrue %13
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %9 None
+OpSwitch %6 %5 1 %8
+%8 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %12 %11 %10
+%11 = OpLabel
+OpBranch %9
+%10 = OpLabel
+OpReturn
+%5 = OpLabel
+OpBranch %9
+%9 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+} // namespace
+} // namespace tint::writer::spirv