[tint][ir] Serialize Switch instructions

Change-Id: I3741dd6cdde76ad8f2773d6c1011aa8cc372fcce
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/165004
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index e69373b..33a00e9 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -53,6 +53,7 @@
     Builder b{mod_out_};
 
     Vector<ir::ExitIf*, 32> exit_ifs_{};
+    Vector<ir::ExitSwitch*, 32> exit_switches_{};
 
     void Decode() {
         {
@@ -106,6 +107,9 @@
         for (auto* exit : exit_ifs_) {
             InferControlInstruction(exit, &ExitIf::SetIf);
         }
+        for (auto* exit : exit_switches_) {
+            InferControlInstruction(exit, &ExitSwitch::SetSwitch);
+        }
     }
 
     template <typename EXIT, typename CTRL_INST>
@@ -204,6 +208,9 @@
             case pb::Instruction::KindCase::kExitIf:
                 inst_out = CreateInstructionExitIf(inst_in.exit_if());
                 break;
+            case pb::Instruction::KindCase::kExitSwitch:
+                inst_out = CreateInstructionExitSwitch(inst_in.exit_switch());
+                break;
             case pb::Instruction::KindCase::kDiscard:
                 inst_out = CreateInstructionDiscard(inst_in.discard());
                 break;
@@ -231,6 +238,9 @@
             case pb::Instruction::KindCase::kSwizzle:
                 inst_out = CreateInstructionSwizzle(inst_in.swizzle());
                 break;
+            case pb::Instruction::KindCase::kSwitch:
+                inst_out = CreateInstructionSwitch(inst_in.switch_());
+                break;
             case pb::Instruction::KindCase::kUnary:
                 inst_out = CreateInstructionUnary(inst_in.unary());
                 break;
@@ -291,6 +301,12 @@
         return exit_out;
     }
 
+    ir::ExitSwitch* CreateInstructionExitSwitch(const pb::InstructionExitSwitch&) {
+        auto* exit_out = mod_out_.instructions.Create<ir::ExitSwitch>();
+        exit_switches_.Push(exit_out);
+        return exit_out;
+    }
+
     ir::Discard* CreateInstructionDiscard(const pb::InstructionDiscard&) {
         return mod_out_.instructions.Create<ir::Discard>();
     }
@@ -342,6 +358,26 @@
         return swizzle_out;
     }
 
+    ir::Switch* CreateInstructionSwitch(const pb::InstructionSwitch& switch_in) {
+        auto* switch_out = mod_out_.instructions.Create<ir::Switch>();
+        for (auto& case_in : switch_in.cases()) {
+            ir::Switch::Case case_out{};
+            case_out.block = Block(case_in.block());
+            case_out.block->SetParent(switch_out);
+            for (auto selector_in : case_in.selectors()) {
+                ir::Switch::CaseSelector selector_out{};
+                selector_out.val = b.Constant(ConstantValue(selector_in));
+                case_out.selectors.Push(std::move(selector_out));
+            }
+            if (case_in.is_default()) {
+                ir::Switch::CaseSelector selector_out{};
+                case_out.selectors.Push(std::move(selector_out));
+            }
+            switch_out->Cases().Push(std::move(case_out));
+        }
+        return switch_out;
+    }
+
     ir::Unary* CreateInstructionUnary(const pb::InstructionUnary& unary_in) {
         auto* unary_out = mod_out_.instructions.Create<ir::Unary>();
         unary_out->SetOp(UnaryOp(unary_in.op()));
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index e8910d1..c9ee44c 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -41,6 +41,7 @@
 #include "src/tint/lang/core/ir/core_builtin_call.h"
 #include "src/tint/lang/core/ir/discard.h"
 #include "src/tint/lang/core/ir/exit_if.h"
+#include "src/tint/lang/core/ir/exit_switch.h"
 #include "src/tint/lang/core/ir/function_param.h"
 #include "src/tint/lang/core/ir/if.h"
 #include "src/tint/lang/core/ir/let.h"
@@ -50,6 +51,7 @@
 #include "src/tint/lang/core/ir/return.h"
 #include "src/tint/lang/core/ir/store.h"
 #include "src/tint/lang/core/ir/store_vector_element.h"
+#include "src/tint/lang/core/ir/switch.h"
 #include "src/tint/lang/core/ir/swizzle.h"
 #include "src/tint/lang/core/ir/unary.h"
 #include "src/tint/lang/core/ir/user_call.h"
@@ -154,7 +156,7 @@
     // Instructions
     ////////////////////////////////////////////////////////////////////////////
     void Instruction(pb::Instruction& inst_out, const ir::Instruction* inst_in) {
-        Switch(
+        tint::Switch(
             inst_in,  //
             [&](const ir::Access* i) { InstructionAccess(*inst_out.mutable_access(), i); },
             [&](const ir::Binary* i) { InstructionBinary(*inst_out.mutable_binary(), i); },
@@ -165,6 +167,9 @@
             [&](const ir::Convert* i) { InstructionConvert(*inst_out.mutable_convert(), i); },
             [&](const ir::Discard* i) { InstructionDiscard(*inst_out.mutable_discard(), i); },
             [&](const ir::ExitIf* i) { InstructionExitIf(*inst_out.mutable_exit_if(), i); },
+            [&](const ir::ExitSwitch* i) {
+                InstructionExitSwitch(*inst_out.mutable_exit_switch(), i);
+            },
             [&](const ir::If* i) { InstructionIf(*inst_out.mutable_if_(), i); },
             [&](const ir::Let* i) { InstructionLet(*inst_out.mutable_let(), i); },
             [&](const ir::Load* i) { InstructionLoad(*inst_out.mutable_load(), i); },
@@ -176,6 +181,7 @@
             [&](const ir::StoreVectorElement* i) {
                 InstructionStoreVectorElement(*inst_out.mutable_store_vector_element(), i);
             },
+            [&](const ir::Switch* i) { InstructionSwitch(*inst_out.mutable_switch_(), i); },
             [&](const ir::Swizzle* i) { InstructionSwizzle(*inst_out.mutable_swizzle(), i); },
             [&](const ir::Unary* i) { InstructionUnary(*inst_out.mutable_unary(), i); },
             [&](const ir::UserCall* i) { InstructionUserCall(*inst_out.mutable_user_call(), i); },
@@ -217,6 +223,8 @@
 
     void InstructionExitIf(pb::InstructionExitIf&, const ir::ExitIf*) {}
 
+    void InstructionExitSwitch(pb::InstructionExitSwitch&, const ir::ExitSwitch*) {}
+
     void InstructionLet(pb::InstructionLet&, const ir::Let*) {}
 
     void InstructionLoad(pb::InstructionLoad&, const ir::Load*) {}
@@ -237,6 +245,20 @@
         }
     }
 
+    void InstructionSwitch(pb::InstructionSwitch& switch_out, const ir::Switch* switch_in) {
+        for (auto& case_in : switch_in->Cases()) {
+            auto& case_out = *switch_out.add_cases();
+            case_out.set_block(Block(case_in.block));
+            for (auto& selector_in : case_in.selectors) {
+                if (selector_in.IsDefault()) {
+                    case_out.set_is_default(true);
+                } else {
+                    case_out.add_selectors(ConstantValue(selector_in.val->Value()));
+                }
+            }
+        }
+    }
+
     void InstructionUnary(pb::InstructionUnary& unary_out, const ir::Unary* unary_in) {
         unary_out.set_op(UnaryOp(unary_in->Op()));
     }
@@ -260,7 +282,7 @@
         }
         return types_.GetOrCreate(type_in, [&]() -> uint32_t {
             pb::Type type_out;
-            Switch(
+            tint::Switch(
                 type_in,  //
                 [&](const core::type::Void*) { type_out.set_basic(pb::TypeBasic::void_); },
                 [&](const core::type::Bool*) { type_out.set_basic(pb::TypeBasic::bool_); },
@@ -341,7 +363,7 @@
     void TypeArray(pb::TypeArray& array_out, const core::type::Array* array_in) {
         array_out.set_element(Type(array_in->ElemType()));
         array_out.set_stride(array_in->Stride());
-        Switch(
+        tint::Switch(
             array_in->Count(),  //
             [&](const core::type::ConstantArrayCount* c) { array_out.set_count(c->value); },
             TINT_ICE_ON_NO_MATCH);
@@ -356,7 +378,7 @@
         }
         return values_.GetOrCreate(value_in, [&] {
             auto& value_out = *mod_out_.add_values();
-            Switch(
+            tint::Switch(
                 value_in,
                 [&](const ir::InstructionResult* v) {
                     InstructionResult(*value_out.mutable_instruction_result(), v);
@@ -395,7 +417,7 @@
         }
         return constant_values_.GetOrCreate(constant_in, [&] {
             pb::ConstantValue constant_out;
-            Switch(
+            tint::Switch(
                 constant_in,  //
                 [&](const core::constant::Scalar<bool>* b) {
                     constant_out.mutable_scalar()->set_bool_(b->value);
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 95a3d21..70ce8ae 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -215,7 +215,9 @@
         InstructionStoreVectorElement store_vector_element = 18;
         InstructionSwizzle swizzle = 19;
         InstructionIf if = 20;
-        InstructionExitIf exit_if = 21;
+        InstructionSwitch switch = 21;
+        InstructionExitIf exit_if = 22;
+        InstructionExitSwitch exit_switch = 23;
     }
 }
 
@@ -270,8 +272,20 @@
     optional uint32 false = 2;  // Module.blocks
 }
 
+message InstructionSwitch {
+    repeated SwitchCase cases = 1;  // Module.blocks
+}
+
 message InstructionExitIf {}
 
+message InstructionExitSwitch {}
+
+message SwitchCase {
+    uint32 block = 1;               // Module.blocks
+    repeated uint32 selectors = 2;  // Module.constant_values
+    bool is_default = 3;
+}
+
 message BindingPoint {
     uint32 group = 1;
     uint32 binding = 2;
diff --git a/src/tint/lang/core/ir/binary/roundtrip_test.cc b/src/tint/lang/core/ir/binary/roundtrip_test.cc
index 6fbb02a..28d5e7c 100644
--- a/src/tint/lang/core/ir/binary/roundtrip_test.cc
+++ b/src/tint/lang/core/ir/binary/roundtrip_test.cc
@@ -497,5 +497,35 @@
     RUN_TEST();
 }
 
+TEST_F(IRBinaryRoundtripTest, Switch) {
+    auto* x = b.FunctionParam<i32>("x");
+    auto* fn = b.Function("Function", ty.i32());
+    fn->SetParams({x});
+    b.Append(fn->Block(), [&] {
+        auto* switch_ = b.Switch(x);
+        b.Append(b.Case(switch_, {b.Constant(1_i)}), [&] { b.Return(fn, 1_i); });
+        b.Append(b.Case(switch_, {b.Constant(2_i), b.Constant(3_i)}), [&] { b.Return(fn, 2_i); });
+        b.Append(b.Case(switch_, {nullptr}), [&] { b.Return(fn, 3_i); });
+    });
+    RUN_TEST();
+}
+
+TEST_F(IRBinaryRoundtripTest, SwitchResults) {
+    auto* x = b.FunctionParam<i32>("x");
+    auto* fn = b.Function("Function", ty.i32());
+    fn->SetParams({x});
+    b.Append(fn->Block(), [&] {
+        auto* switch_ = b.Switch(x);
+        auto* res = b.InstructionResult<i32>();
+        switch_->SetResults(Vector{res});
+        b.Append(b.Case(switch_, {b.Constant(1_i)}), [&] { b.ExitSwitch(switch_, 1_i); });
+        b.Append(b.Case(switch_, {b.Constant(2_i), b.Constant(3_i)}),
+                 [&] { b.ExitSwitch(switch_, 2_i); });
+        b.Append(b.Case(switch_, {nullptr}), [&] { b.ExitSwitch(switch_, 3_i); });
+        b.Return(fn, res);
+    });
+    RUN_TEST();
+}
+
 }  // namespace
 }  // namespace tint::core::ir::binary
diff --git a/src/tint/lang/core/ir/exit_switch.cc b/src/tint/lang/core/ir/exit_switch.cc
index 5d90811..0ac5cc2 100644
--- a/src/tint/lang/core/ir/exit_switch.cc
+++ b/src/tint/lang/core/ir/exit_switch.cc
@@ -38,6 +38,8 @@
 
 namespace tint::core::ir {
 
+ExitSwitch::ExitSwitch() = default;
+
 ExitSwitch::ExitSwitch(ir::Switch* sw, VectorRef<Value*> args /* = tint::Empty */) {
     SetSwitch(sw);
     AddOperands(ExitSwitch::kArgsOperandOffset, std::move(args));
diff --git a/src/tint/lang/core/ir/exit_switch.h b/src/tint/lang/core/ir/exit_switch.h
index 878f73b..d7d3032 100644
--- a/src/tint/lang/core/ir/exit_switch.h
+++ b/src/tint/lang/core/ir/exit_switch.h
@@ -46,6 +46,9 @@
     /// The base offset in Operands() for the args
     static constexpr size_t kArgsOperandOffset = 0;
 
+    /// Constructor (no operands, no switch)
+    ExitSwitch();
+
     /// Constructor
     /// @param sw the switch being exited
     /// @param args the target MultiInBlock arguments
diff --git a/src/tint/lang/core/ir/switch.cc b/src/tint/lang/core/ir/switch.cc
index 30e08bf..1b0d2bc 100644
--- a/src/tint/lang/core/ir/switch.cc
+++ b/src/tint/lang/core/ir/switch.cc
@@ -37,6 +37,8 @@
 
 namespace tint::core::ir {
 
+Switch::Switch() = default;
+
 Switch::Switch(Value* cond) {
     TINT_ASSERT(cond);
 
diff --git a/src/tint/lang/core/ir/switch.h b/src/tint/lang/core/ir/switch.h
index 3b44161..d0783b8 100644
--- a/src/tint/lang/core/ir/switch.h
+++ b/src/tint/lang/core/ir/switch.h
@@ -81,6 +81,9 @@
         ConstPropagatingPtr<ir::Block> block;
     };
 
+    /// Constructor (no results, no operands, no cases)
+    Switch();
+
     /// Constructor
     /// @param cond the condition
     explicit Switch(Value* cond);