[ir][spirv-writer] Implement switch instructions

Bug: tint:1906
Change-Id: Ie7de641227e4a2de2b812a8c79789c34af9f4e11
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134741
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 3305b8d..bc74d5d 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1972,6 +1972,7 @@
         "writer/spirv/ir/generator_impl_ir_function_test.cc",
         "writer/spirv/ir/generator_impl_ir_if_test.cc",
         "writer/spirv/ir/generator_impl_ir_loop_test.cc",
+        "writer/spirv/ir/generator_impl_ir_switch_test.cc",
         "writer/spirv/ir/generator_impl_ir_test.cc",
         "writer/spirv/ir/generator_impl_ir_type_test.cc",
         "writer/spirv/ir/generator_impl_ir_var_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index d8fe22a..277d819 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -1261,6 +1261,7 @@
         writer/spirv/ir/generator_impl_ir_function_test.cc
         writer/spirv/ir/generator_impl_ir_if_test.cc
         writer/spirv/ir/generator_impl_ir_loop_test.cc
+        writer/spirv/ir/generator_impl_ir_switch_test.cc
         writer/spirv/ir/generator_impl_ir_test.cc
         writer/spirv/ir/generator_impl_ir_type_test.cc
         writer/spirv/ir/generator_impl_ir_var_test.cc
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index 89a1a1c..a32f5b7 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -25,6 +25,7 @@
 #include "src/tint/ir/continue.h"
 #include "src/tint/ir/exit_if.h"
 #include "src/tint/ir/exit_loop.h"
+#include "src/tint/ir/exit_switch.h"
 #include "src/tint/ir/if.h"
 #include "src/tint/ir/load.h"
 #include "src/tint/ir/loop.h"
@@ -32,6 +33,7 @@
 #include "src/tint/ir/next_iteration.h"
 #include "src/tint/ir/return.h"
 #include "src/tint/ir/store.h"
+#include "src/tint/ir/switch.h"
 #include "src/tint/ir/transform/add_empty_entry_point.h"
 #include "src/tint/ir/user_call.h"
 #include "src/tint/ir/var.h"
@@ -343,6 +345,10 @@
                 EmitLoop(l);
                 return 0u;
             },
+            [&](const ir::Switch* sw) {
+                EmitSwitch(sw);
+                return 0u;
+            },
             [&](const ir::Store* s) {
                 EmitStore(s);
                 return 0u;
@@ -397,6 +403,9 @@
         [&](const ir::ExitLoop* loop) {
             current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Merge())});
         },
+        [&](const ir::ExitSwitch* swtch) {
+            current_function_.push_inst(spv::Op::OpBranch, {Label(swtch->Switch()->Merge())});
+        },
         [&](const ir::NextIteration* loop) {
             current_function_.push_inst(spv::Op::OpBranch, {Label(loop->Loop()->Start())});
         },
@@ -651,6 +660,45 @@
     EmitBlock(loop->Merge());
 }
 
+void GeneratorImplIr::EmitSwitch(const ir::Switch* swtch) {
+    // Find the label of the case holding the default selector; there must be exactly one.
+    uint32_t default_label = 0u;  // 0 means "not found"; checked by the TINT_ASSERT below
+    for (auto& c : swtch->Cases()) {  // NOTE(review): expected test IDs track Label() call order
+        for (auto& sel : c.selectors) {
+            if (sel.IsDefault()) {
+                default_label = Label(c.Start());  // no break: the last default wins
+            }
+        }
+    }
+    TINT_ASSERT(Writer, default_label != 0u);
+
+    // Build the OpSwitch operands: condition, default label, then (literal, label) pairs.
+    OperandList switch_operands = {Value(swtch->Condition()), default_label};
+    for (auto& c : swtch->Cases()) {
+        auto label = Label(c.Start());  // one label per case, shared by all of its selectors
+        for (auto& sel : c.selectors) {
+            if (sel.IsDefault()) {
+                continue;  // covered by the default label operand above
+            }
+            switch_operands.push_back(sel.val->Value()->ValueAs<uint32_t>());
+            switch_operands.push_back(label);  // NOTE(review): negative i32 selectors untested
+        }
+    }
+
+    // Emit OpSelectionMerge (naming the merge block) immediately before OpSwitch.
+    current_function_.push_inst(spv::Op::OpSelectionMerge,
+                                {Label(swtch->Merge()), U32Operand(SpvSelectionControlMaskNone)});
+    current_function_.push_inst(spv::Op::OpSwitch, switch_operands);
+
+    // Emit every case body, in the order the cases are stored on the switch.
+    for (auto& c : swtch->Cases()) {
+        EmitBlock(c.Start());
+    }
+
+    // Emit the merge block last; ExitSwitch instructions lower to an OpBranch to it.
+    EmitBlock(swtch->Merge());
+}
+
 void GeneratorImplIr::EmitStore(const ir::Store* store) {
     current_function_.push_inst(spv::Op::OpStore, {Value(store->To()), Value(store->From())});
 }
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 5293fea..3c8783e 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -39,6 +39,7 @@
 class Loop;
 class Module;
 class Store;
+class Switch;
 class UserCall;
 class Value;
 class Var;
@@ -130,6 +131,10 @@
     /// @param store the store instruction to emit
     void EmitStore(const ir::Store* store);
 
+    /// Emit a switch instruction: OpSelectionMerge, OpSwitch, then the case and merge blocks.
+    /// @param swtch the switch instruction to emit; must contain a default selector
+    void EmitSwitch(const ir::Switch* swtch);
+
     /// Emit a user call instruction.
     /// @param call the user call instruction to emit
     /// @returns the result ID of the instruction
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
new file mode 100644
index 0000000..5ef5901
--- /dev/null
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_switch_test.cc
@@ -0,0 +1,221 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/writer/spirv/ir/test_helper_ir.h"
+
+using namespace tint::number_suffixes;  // NOLINT
+
+namespace tint::writer::spirv {
+namespace {
+
+TEST_F(SpvGeneratorImplTest, Switch_Basic) {
+    auto* fn = b.CreateFunction("foo", mod.Types().void_());
+
+    // A switch whose only case is the default one, which immediately exits the switch.
+    auto* sw = b.CreateSwitch(b.Constant(42_i));
+    auto* default_case = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector()});
+    default_case->Instructions().Push(b.ExitSwitch(sw));
+
+    // Return from the merge block; the switch is the function's only top-level instruction.
+    sw->Merge()->Instructions().Push(b.Return(fn));
+    fn->StartTarget()->Instructions().Push(sw);
+
+    generator_.EmitFunction(fn);
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %8 None
+OpSwitch %6 %5
+%5 = OpLabel
+OpBranch %8
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_MultipleCases) {
+    auto* fn = b.CreateFunction("foo", mod.Types().void_());
+    auto* sw = b.CreateSwitch(b.Constant(42_i));
+
+    // Cases 1 and 2 each get their own block, listed before the default case.
+    auto* case_one = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+    case_one->Instructions().Push(b.ExitSwitch(sw));
+
+    auto* case_two = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+    case_two->Instructions().Push(b.ExitSwitch(sw));
+
+    // The default case is created last but is still the first label operand of OpSwitch.
+    auto* default_case = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector()});
+    default_case->Instructions().Push(b.ExitSwitch(sw));
+
+    sw->Merge()->Instructions().Push(b.Return(fn));
+    fn->StartTarget()->Instructions().Push(sw);
+
+    generator_.EmitFunction(fn);
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %10 None
+OpSwitch %6 %5 1 %8 2 %9
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranch %10
+%5 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_MultipleSelectorsPerCase) {
+    auto* fn = b.CreateFunction("foo", mod.Types().void_());
+    auto* sw = b.CreateSwitch(b.Constant(42_i));
+
+    // Each case has two selectors, each emitted as its own (literal, label) OpSwitch pair.
+    auto* case_one = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)},
+                                                    ir::Switch::CaseSelector{b.Constant(3_i)}});
+    case_one->Instructions().Push(b.ExitSwitch(sw));
+
+    auto* case_two = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)},
+                                                    ir::Switch::CaseSelector{b.Constant(4_i)}});
+    case_two->Instructions().Push(b.ExitSwitch(sw));
+
+    // The default block also handles selector 5, which is emitted as an explicit pair.
+    auto* default_case = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(5_i)},
+                                                        ir::Switch::CaseSelector()});
+    default_case->Instructions().Push(b.ExitSwitch(sw));
+
+    sw->Merge()->Instructions().Push(b.Return(fn));
+    fn->StartTarget()->Instructions().Push(sw);
+
+    generator_.EmitFunction(fn);
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %10 None
+OpSwitch %6 %5 1 %8 3 %8 2 %9 4 %9 5 %5
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranch %10
+%5 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_AllCasesReturn) {
+    auto* fn = b.CreateFunction("foo", mod.Types().void_());
+    auto* sw = b.CreateSwitch(b.Constant(42_i));
+
+    // Every case returns from the function, so the switch's merge block is unreachable.
+    auto* case_one = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+    case_one->Instructions().Push(b.Return(fn));
+    auto* case_two = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(2_i)}});
+    case_two->Instructions().Push(b.Return(fn));
+
+    auto* default_case = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector()});
+    default_case->Instructions().Push(b.Return(fn));
+
+    // The merge block is left empty; the expected output terminates it with OpUnreachable.
+    fn->StartTarget()->Instructions().Push(sw);
+
+    generator_.EmitFunction(fn);
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %10 None
+OpSwitch %6 %5 1 %8 2 %9
+%8 = OpLabel
+OpReturn
+%9 = OpLabel
+OpReturn
+%5 = OpLabel
+OpReturn
+%10 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, Switch_ConditionalBreak) {
+    auto* fn = b.CreateFunction("foo", mod.Types().void_());
+    auto* sw = b.CreateSwitch(b.Constant(42_i));
+
+    // Case 1 body: `if (true) { break; }`, with a return in the if's merge block.
+    auto* if_break = b.CreateIf(b.Constant(true));
+    if_break->True()->Instructions().Push(b.ExitSwitch(sw));
+    if_break->False()->Instructions().Push(b.ExitIf(if_break));
+    if_break->Merge()->Instructions().Push(b.Return(fn));
+
+    auto* case_one = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector{b.Constant(1_i)}});
+    case_one->Instructions().Push(if_break);
+
+    auto* default_case = b.CreateCase(sw, utils::Vector{ir::Switch::CaseSelector()});
+    default_case->Instructions().Push(b.ExitSwitch(sw));
+
+    // Both the conditional break and the default case branch to the switch merge block.
+    sw->Merge()->Instructions().Push(b.Return(fn));
+    fn->StartTarget()->Instructions().Push(sw);
+
+    generator_.EmitFunction(fn);
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%7 = OpTypeInt 32 1
+%6 = OpConstant %7 42
+%13 = OpTypeBool
+%12 = OpConstantTrue %13
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpSelectionMerge %9 None
+OpSwitch %6 %5 1 %8
+%8 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %12 %11 %10
+%11 = OpLabel
+OpBranch %9
+%10 = OpLabel
+OpReturn
+%5 = OpLabel
+OpBranch %9
+%9 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+}  // namespace
+}  // namespace tint::writer::spirv