tint: Move Switch() to own header

castable.h is bigger than it needs to be, and pretty much every tint .cc file includes castable.h
Reduce the amount of code that .cc files that don't use Switch() need to compile.

Change-Id: Ibb4e8b0bc7104ad33a7f2f39587c7d9e749fee97
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/123401
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index 4c02044..a9871da 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -213,6 +213,7 @@
     "scope_stack.h",
     "source.cc",
     "source.h",
+    "switch.h",
     "symbol.cc",
     "symbol.h",
     "symbol_table.cc",
@@ -1982,6 +1983,7 @@
       "reflection_test.cc",
       "scope_stack_test.cc",
       "source_test.cc",
+      "switch_test.cc",
       "symbol_table_test.cc",
       "symbol_test.cc",
       "traits_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index bbaaa9f..91f8178 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -280,6 +280,7 @@
   resolver/validator.cc
   resolver/validator.h
   scope_stack.h
+  switch.h
   sem/array_count.cc
   sem/array_count.h
   sem/behavior.cc
@@ -960,6 +961,7 @@
     sem/struct_test.cc
     sem/value_expression_test.cc
     source_test.cc
+    switch_test.cc
     symbol_table_test.cc
     symbol_test.cc
     test_main.cc
@@ -1455,7 +1457,7 @@
   endif()
 
   list(APPEND TINT_BENCHMARK_SRCS
-    "castable_bench.cc"
+    "switch_bench.cc"
     "bench/benchmark.cc"
     "reader/wgsl/parser_bench.cc"
   )
diff --git a/src/tint/ast/module.cc b/src/tint/ast/module.cc
index 32dafe8..fa652de 100644
--- a/src/tint/ast/module.cc
+++ b/src/tint/ast/module.cc
@@ -18,6 +18,7 @@
 
 #include "src/tint/ast/type_decl.h"
 #include "src/tint/program_builder.h"
+#include "src/tint/switch.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::ast::Module);
 
diff --git a/src/tint/ast/traverse_expressions.h b/src/tint/ast/traverse_expressions.h
index fb8aad1..84a10c3 100644
--- a/src/tint/ast/traverse_expressions.h
+++ b/src/tint/ast/traverse_expressions.h
@@ -25,6 +25,7 @@
 #include "src/tint/ast/member_accessor_expression.h"
 #include "src/tint/ast/phony_expression.h"
 #include "src/tint/ast/unary_op_expression.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/compiler_macros.h"
 #include "src/tint/utils/reverse.h"
 #include "src/tint/utils/vector.h"
diff --git a/src/tint/castable.h b/src/tint/castable.h
index d6597fc..acb9a18 100644
--- a/src/tint/castable.h
+++ b/src/tint/castable.h
@@ -532,280 +532,6 @@
 template <typename... TYPES>
 using CastableCommonBase = detail::CastableCommonBase<TYPES...>;
 
-/// Default can be used as the default case for a Switch(), when all previous cases failed to match.
-///
-/// Example:
-/// ```
-/// Switch(object,
-///     [&](TypeA*) { /* ... */ },
-///     [&](TypeB*) { /* ... */ },
-///     [&](Default) { /* If not TypeA or TypeB */ });
-/// ```
-struct Default {};
-
-namespace detail {
-
-/// Evaluates to the Switch case type being matched by the switch case function `FN`.
-/// @note does not handle the Default case
-/// @see Switch().
-template <typename FN>
-using SwitchCaseType = std::remove_pointer_t<traits::ParameterType<std::remove_reference_t<FN>, 0>>;
-
-/// Evaluates to true if the function `FN` has the signature of a Default case in a Switch().
-/// @see Switch().
-template <typename FN>
-inline constexpr bool IsDefaultCase =
-    std::is_same_v<traits::ParameterType<std::remove_reference_t<FN>, 0>, Default>;
-
-/// Searches the list of Switch cases for a Default case, returning the index of the Default case.
-/// If the a Default case is not found in the tuple, then -1 is returned.
-template <typename TUPLE, std::size_t START_IDX = 0>
-constexpr int IndexOfDefaultCase() {
-    if constexpr (START_IDX < std::tuple_size_v<TUPLE>) {
-        return IsDefaultCase<std::tuple_element_t<START_IDX, TUPLE>>
-                   ? static_cast<int>(START_IDX)
-                   : IndexOfDefaultCase<TUPLE, START_IDX + 1>();
-    } else {
-        return -1;
-    }
-}
-
-/// The implementation of Switch() for non-Default cases.
-/// Switch splits the cases into two a low and high block of cases, and quickly rules out blocks
-/// that cannot match by comparing the HashCode of the object and the cases in the block. If a block
-/// of cases may match the given object's type, then that block is split into two, and the process
-/// recurses. When NonDefaultCases() is called with a single case, then As<> will be used to
-/// dynamically cast to the case type and if the cast succeeds, then the case handler is called.
-/// @returns true if a case handler was found, otherwise false.
-template <typename T, typename RETURN_TYPE, typename... CASES>
-inline bool NonDefaultCases([[maybe_unused]] T* object,
-                            const TypeInfo* type,
-                            [[maybe_unused]] RETURN_TYPE* result,
-                            std::tuple<CASES...>&& cases) {
-    using Cases = std::tuple<CASES...>;
-
-    static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
-    static constexpr size_t kNumCases = sizeof...(CASES);
-
-    if constexpr (kNumCases == 0) {
-        // No cases. Nothing to do.
-        return false;
-    } else if constexpr (kNumCases == 1) {  // NOLINT: cpplint doesn't understand
-                                            // `else if constexpr`
-        // Single case.
-        using CaseFunc = std::tuple_element_t<0, Cases>;
-        static_assert(!IsDefaultCase<CaseFunc>, "NonDefaultCases called with a Default case");
-        // Attempt to dynamically cast the object to the handler type. If that succeeds, call the
-        // case handler with the cast object.
-        using CaseType = SwitchCaseType<CaseFunc>;
-        if (type->Is<CaseType>()) {
-            auto* ptr = static_cast<CaseType*>(object);
-            if constexpr (kHasReturnType) {
-                new (result) RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<0>(cases)(ptr)));
-            } else {
-                std::get<0>(cases)(ptr);
-            }
-            return true;
-        }
-        return false;
-    } else {
-        // Multiple cases.
-        // Check the hashcode bits to see if there's any possibility of a case matching in these
-        // cases. If there isn't, we can skip all these cases.
-        if (MaybeAnyOf(TypeInfo::CombinedHashCodeOf<SwitchCaseType<CASES>...>(),
-                       type->full_hashcode)) {
-            // Split the cases into two, and recurse.
-            constexpr size_t kMid = kNumCases / 2;
-            return NonDefaultCases(object, type, result, traits::Slice<0, kMid>(cases)) ||
-                   NonDefaultCases(object, type, result,
-                                   traits::Slice<kMid, kNumCases - kMid>(cases));
-        } else {
-            return false;
-        }
-    }
-}
-
-/// The implementation of Switch() for all cases.
-/// @see NonDefaultCases
-template <typename T, typename RETURN_TYPE, typename... CASES>
-inline void SwitchCases(T* object, RETURN_TYPE* result, std::tuple<CASES...>&& cases) {
-    using Cases = std::tuple<CASES...>;
-
-    static constexpr int kDefaultIndex = detail::IndexOfDefaultCase<Cases>();
-    static constexpr bool kHasDefaultCase = kDefaultIndex >= 0;
-    static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
-
-    // Static assertions
-    static constexpr bool kDefaultIsOK =
-        kDefaultIndex == -1 || kDefaultIndex == static_cast<int>(std::tuple_size_v<Cases> - 1);
-    static constexpr bool kReturnIsOK =
-        kHasDefaultCase || !kHasReturnType || std::is_constructible_v<RETURN_TYPE>;
-    static_assert(kDefaultIsOK, "Default case must be last in Switch()");
-    static_assert(kReturnIsOK,
-                  "Switch() requires either a Default case or a return type that is either void or "
-                  "default-constructable");
-
-    // If the static asserts have fired, don't bother spewing more errors below
-    static constexpr bool kAllOK = kDefaultIsOK && kReturnIsOK;
-    if constexpr (kAllOK) {
-        if (object) {
-            auto* type = &object->TypeInfo();
-            if constexpr (kHasDefaultCase) {
-                // Evaluate non-default cases.
-                if (!detail::NonDefaultCases<T>(object, type, result,
-                                                traits::Slice<0, kDefaultIndex>(cases))) {
-                    // Nothing matched. Evaluate default case.
-                    if constexpr (kHasReturnType) {
-                        new (result) RETURN_TYPE(
-                            static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
-                    } else {
-                        std::get<kDefaultIndex>(cases)({});
-                    }
-                }
-            } else {
-                if (!detail::NonDefaultCases<T>(object, type, result, std::move(cases))) {
-                    // Nothing matched. No default case.
-                    if constexpr (kHasReturnType) {
-                        new (result) RETURN_TYPE();
-                    }
-                }
-            }
-        } else {
-            // Object is nullptr, so no cases can match
-            if constexpr (kHasDefaultCase) {
-                // Evaluate default case.
-                if constexpr (kHasReturnType) {
-                    new (result)
-                        RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
-                } else {
-                    std::get<kDefaultIndex>(cases)({});
-                }
-            } else {
-                // No default case, no case can match.
-                if constexpr (kHasReturnType) {
-                    new (result) RETURN_TYPE();
-                }
-            }
-        }
-    }
-}
-
-/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore.
-template <typename T>
-using NullptrToIgnore = std::conditional_t<std::is_same_v<T, std::nullptr_t>, Ignore, T>;
-
-/// Resolves to `const TYPE` if any of `CASE_RETURN_TYPES` are const or pointer-to-const, otherwise
-/// resolves to TYPE.
-template <typename TYPE, typename... CASE_RETURN_TYPES>
-using PropagateReturnConst = std::conditional_t<
-    // Are any of the pointer-stripped types const?
-    (std::is_const_v<std::remove_pointer_t<CASE_RETURN_TYPES>> || ...),
-    const TYPE,  // Yes: Apply const to TYPE
-    TYPE>;       // No:  Passthrough
-
-/// SwitchReturnTypeImpl is the implementation of SwitchReturnType
-template <bool IS_CASTABLE, typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
-struct SwitchReturnTypeImpl;
-
-/// SwitchReturnTypeImpl specialization for non-castable case types and an explicitly specified
-/// return type.
-template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
-struct SwitchReturnTypeImpl</*IS_CASTABLE*/ false, REQUESTED_TYPE, CASE_RETURN_TYPES...> {
-    /// Resolves to `REQUESTED_TYPE`
-    using type = REQUESTED_TYPE;
-};
-
-/// SwitchReturnTypeImpl specialization for non-castable case types and an inferred return type.
-template <typename... CASE_RETURN_TYPES>
-struct SwitchReturnTypeImpl</*IS_CASTABLE*/ false, Infer, CASE_RETURN_TYPES...> {
-    /// Resolves to the common type for all the cases return types.
-    using type = std::common_type_t<CASE_RETURN_TYPES...>;
-};
-
-/// SwitchReturnTypeImpl specialization for castable case types and an explicitly specified return
-/// type.
-template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
-struct SwitchReturnTypeImpl</*IS_CASTABLE*/ true, REQUESTED_TYPE, CASE_RETURN_TYPES...> {
-  public:
-    /// Resolves to `const REQUESTED_TYPE*` or `REQUESTED_TYPE*`
-    using type = PropagateReturnConst<std::remove_pointer_t<REQUESTED_TYPE>, CASE_RETURN_TYPES...>*;
-};
-
-/// SwitchReturnTypeImpl specialization for castable case types and an inferred return type.
-template <typename... CASE_RETURN_TYPES>
-struct SwitchReturnTypeImpl</*IS_CASTABLE*/ true, Infer, CASE_RETURN_TYPES...> {
-  private:
-    using InferredType =
-        CastableCommonBase<detail::NullptrToIgnore<std::remove_pointer_t<CASE_RETURN_TYPES>>...>;
-
-  public:
-    /// `const T*` or `T*`, where T is the common base type for all the castable case types.
-    using type = PropagateReturnConst<InferredType, CASE_RETURN_TYPES...>*;
-};
-
-/// Resolves to the return type for a Switch() with the requested return type `REQUESTED_TYPE` and
-/// case statement return types. If `REQUESTED_TYPE` is Infer then the return type will be inferred
-/// from the case return types.
-template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
-using SwitchReturnType = typename SwitchReturnTypeImpl<
-    IsCastable<NullptrToIgnore<std::remove_pointer_t<CASE_RETURN_TYPES>>...>,
-    REQUESTED_TYPE,
-    CASE_RETURN_TYPES...>::type;
-
-}  // namespace detail
-
-/// Switch is used to dispatch one of the provided callback case handler functions based on the type
-/// of `object` and the parameter type of the case handlers. Switch will sequentially check the type
-/// of `object` against each of the switch case handler functions, and will invoke the first case
-/// handler function which has a parameter type that matches the object type. When a case handler is
-/// matched, it will be called with the single argument of `object` cast to the case handler's
-/// parameter type. Switch will invoke at most one case handler. Each of the case functions must
-/// have the signature `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
-/// is the return type, consistent across all case handlers.
-///
-/// An optional default case function with the signature `R(Default)` can be used as the last case.
-/// This default case will be called if all previous cases failed to match.
-///
-/// If `object` is nullptr and a default case is provided, then the default case will be called. If
-/// `object` is nullptr and no default case is provided, then no cases will be called.
-///
-/// Example:
-/// ```
-/// Switch(object,
-///     [&](TypeA*) { /* ... */ },
-///     [&](TypeB*) { /* ... */ });
-///
-/// Switch(object,
-///     [&](TypeA*) { /* ... */ },
-///     [&](TypeB*) { /* ... */ },
-///     [&](Default) { /* Called if object is not TypeA or TypeB */ });
-/// ```
-///
-/// @param object the object who's type is used to
-/// @param cases the switch cases
-/// @return the value returned by the called case. If no cases matched, then the zero value for the
-/// consistent case type.
-template <typename RETURN_TYPE = detail::Infer, typename T = CastableBase, typename... CASES>
-inline auto Switch(T* object, CASES&&... cases) {
-    using ReturnType = detail::SwitchReturnType<RETURN_TYPE, traits::ReturnType<CASES>...>;
-    static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
-
-    if constexpr (kHasReturnType) {
-        // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC.
-        struct alignas(alignof(ReturnType)) ReturnStorage {
-            uint8_t data[sizeof(ReturnType)];
-        };
-        ReturnStorage storage;
-        auto* res = utils::Bitcast<ReturnType*>(&storage);
-        TINT_DEFER(res->~ReturnType());
-        detail::SwitchCases(object, res, std::forward_as_tuple(std::forward<CASES>(cases)...));
-        return *res;
-    } else {
-        detail::SwitchCases<T, void>(object, nullptr,
-                                     std::forward_as_tuple(std::forward<CASES>(cases)...));
-    }
-}
-
 }  // namespace tint
 
 TINT_CASTABLE_POP_DISABLE_WARNINGS();
diff --git a/src/tint/castable_test.cc b/src/tint/castable_test.cc
index c452c9b..57bd904 100644
--- a/src/tint/castable_test.cc
+++ b/src/tint/castable_test.cc
@@ -20,6 +20,7 @@
 #include "gtest/gtest.h"
 
 namespace tint {
+namespace {
 
 struct Animal : public tint::Castable<Animal> {};
 struct Amphibian : public tint::Castable<Amphibian, Animal> {};
@@ -31,8 +32,6 @@
 struct Gecko : public tint::Castable<Gecko, Lizard> {};
 struct Iguana : public tint::Castable<Iguana, Lizard> {};
 
-namespace {
-
 TEST(CastableBase, Is) {
     std::unique_ptr<CastableBase> frog = std::make_unique<Frog>();
     std::unique_ptr<CastableBase> bear = std::make_unique<Bear>();
@@ -230,512 +229,6 @@
     ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
 }
 
-TEST(Castable, SwitchNoDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        bool frog_matched_amphibian = false;
-        Switch(
-            frog.get(),  //
-            [&](Reptile*) { FAIL() << "frog is not reptile"; },
-            [&](Mammal*) { FAIL() << "frog is not mammal"; },
-            [&](Amphibian* amphibian) {
-                EXPECT_EQ(amphibian, frog.get());
-                frog_matched_amphibian = true;
-            });
-        EXPECT_TRUE(frog_matched_amphibian);
-    }
-    {
-        bool bear_matched_mammal = false;
-        Switch(
-            bear.get(),  //
-            [&](Reptile*) { FAIL() << "bear is not reptile"; },
-            [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
-            [&](Mammal* mammal) {
-                EXPECT_EQ(mammal, bear.get());
-                bear_matched_mammal = true;
-            });
-        EXPECT_TRUE(bear_matched_mammal);
-    }
-    {
-        bool gecko_matched_reptile = false;
-        Switch(
-            gecko.get(),  //
-            [&](Mammal*) { FAIL() << "gecko is not mammal"; },
-            [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
-            [&](Reptile* reptile) {
-                EXPECT_EQ(reptile, gecko.get());
-                gecko_matched_reptile = true;
-            });
-        EXPECT_TRUE(gecko_matched_reptile);
-    }
-}
-
-TEST(Castable, SwitchWithUnusedDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        bool frog_matched_amphibian = false;
-        Switch(
-            frog.get(),  //
-            [&](Reptile*) { FAIL() << "frog is not reptile"; },
-            [&](Mammal*) { FAIL() << "frog is not mammal"; },
-            [&](Amphibian* amphibian) {
-                EXPECT_EQ(amphibian, frog.get());
-                frog_matched_amphibian = true;
-            },
-            [&](Default) { FAIL() << "default should not have been selected"; });
-        EXPECT_TRUE(frog_matched_amphibian);
-    }
-    {
-        bool bear_matched_mammal = false;
-        Switch(
-            bear.get(),  //
-            [&](Reptile*) { FAIL() << "bear is not reptile"; },
-            [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
-            [&](Mammal* mammal) {
-                EXPECT_EQ(mammal, bear.get());
-                bear_matched_mammal = true;
-            },
-            [&](Default) { FAIL() << "default should not have been selected"; });
-        EXPECT_TRUE(bear_matched_mammal);
-    }
-    {
-        bool gecko_matched_reptile = false;
-        Switch(
-            gecko.get(),  //
-            [&](Mammal*) { FAIL() << "gecko is not mammal"; },
-            [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
-            [&](Reptile* reptile) {
-                EXPECT_EQ(reptile, gecko.get());
-                gecko_matched_reptile = true;
-            },
-            [&](Default) { FAIL() << "default should not have been selected"; });
-        EXPECT_TRUE(gecko_matched_reptile);
-    }
-}
-
-TEST(Castable, SwitchDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        bool frog_matched_default = false;
-        Switch(
-            frog.get(),  //
-            [&](Reptile*) { FAIL() << "frog is not reptile"; },
-            [&](Mammal*) { FAIL() << "frog is not mammal"; },
-            [&](Default) { frog_matched_default = true; });
-        EXPECT_TRUE(frog_matched_default);
-    }
-    {
-        bool bear_matched_default = false;
-        Switch(
-            bear.get(),  //
-            [&](Reptile*) { FAIL() << "bear is not reptile"; },
-            [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
-            [&](Default) { bear_matched_default = true; });
-        EXPECT_TRUE(bear_matched_default);
-    }
-    {
-        bool gecko_matched_default = false;
-        Switch(
-            gecko.get(),  //
-            [&](Mammal*) { FAIL() << "gecko is not mammal"; },
-            [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
-            [&](Default) { gecko_matched_default = true; });
-        EXPECT_TRUE(gecko_matched_default);
-    }
-}
-
-TEST(Castable, SwitchMatchFirst) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    {
-        bool frog_matched_animal = false;
-        Switch(
-            frog.get(),
-            [&](Animal* animal) {
-                EXPECT_EQ(animal, frog.get());
-                frog_matched_animal = true;
-            },
-            [&](Amphibian*) { FAIL() << "animal should have been matched first"; });
-        EXPECT_TRUE(frog_matched_animal);
-    }
-    {
-        bool frog_matched_amphibian = false;
-        Switch(
-            frog.get(),
-            [&](Amphibian* amphibain) {
-                EXPECT_EQ(amphibain, frog.get());
-                frog_matched_amphibian = true;
-            },
-            [&](Animal*) { FAIL() << "amphibian should have been matched first"; });
-        EXPECT_TRUE(frog_matched_amphibian);
-    }
-}
-
-TEST(Castable, SwitchReturnValueWithDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        const char* result = Switch(
-            frog.get(),                              //
-            [](Mammal*) { return "mammal"; },        //
-            [](Amphibian*) { return "amphibian"; },  //
-            [](Default) { return "unknown"; });
-        static_assert(std::is_same_v<decltype(result), const char*>);
-        EXPECT_EQ(std::string(result), "amphibian");
-    }
-    {
-        const char* result = Switch(
-            bear.get(),                              //
-            [](Mammal*) { return "mammal"; },        //
-            [](Amphibian*) { return "amphibian"; },  //
-            [](Default) { return "unknown"; });
-        static_assert(std::is_same_v<decltype(result), const char*>);
-        EXPECT_EQ(std::string(result), "mammal");
-    }
-    {
-        const char* result = Switch(
-            gecko.get(),                             //
-            [](Mammal*) { return "mammal"; },        //
-            [](Amphibian*) { return "amphibian"; },  //
-            [](Default) { return "unknown"; });
-        static_assert(std::is_same_v<decltype(result), const char*>);
-        EXPECT_EQ(std::string(result), "unknown");
-    }
-}
-
-TEST(Castable, SwitchReturnValueWithoutDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        const char* result = Switch(
-            frog.get(),                        //
-            [](Mammal*) { return "mammal"; },  //
-            [](Amphibian*) { return "amphibian"; });
-        static_assert(std::is_same_v<decltype(result), const char*>);
-        EXPECT_EQ(std::string(result), "amphibian");
-    }
-    {
-        const char* result = Switch(
-            bear.get(),                        //
-            [](Mammal*) { return "mammal"; },  //
-            [](Amphibian*) { return "amphibian"; });
-        static_assert(std::is_same_v<decltype(result), const char*>);
-        EXPECT_EQ(std::string(result), "mammal");
-    }
-    {
-        auto* result = Switch(
-            gecko.get(),                       //
-            [](Mammal*) { return "mammal"; },  //
-            [](Amphibian*) { return "amphibian"; });
-        static_assert(std::is_same_v<decltype(result), const char*>);
-        EXPECT_EQ(result, nullptr);
-    }
-}
-
-TEST(Castable, SwitchInferPODReturnTypeWithDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto result = Switch(
-            frog.get(),                       //
-            [](Mammal*) { return 1; },        //
-            [](Amphibian*) { return 2.0f; },  //
-            [](Default) { return 3.0; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 2.0);
-    }
-    {
-        auto result = Switch(
-            bear.get(),                       //
-            [](Mammal*) { return 1.0; },      //
-            [](Amphibian*) { return 2.0f; },  //
-            [](Default) { return 3; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 1.0);
-    }
-    {
-        auto result = Switch(
-            gecko.get(),                   //
-            [](Mammal*) { return 1.0f; },  //
-            [](Amphibian*) { return 2; },  //
-            [](Default) { return 3.0; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 3.0);
-    }
-}
-
-TEST(Castable, SwitchInferPODReturnTypeWithoutDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto result = Switch(
-            frog.get(),                 //
-            [](Mammal*) { return 1; },  //
-            [](Amphibian*) { return 2.0f; });
-        static_assert(std::is_same_v<decltype(result), float>);
-        EXPECT_EQ(result, 2.0f);
-    }
-    {
-        auto result = Switch(
-            bear.get(),                    //
-            [](Mammal*) { return 1.0f; },  //
-            [](Amphibian*) { return 2; });
-        static_assert(std::is_same_v<decltype(result), float>);
-        EXPECT_EQ(result, 1.0f);
-    }
-    {
-        auto result = Switch(
-            gecko.get(),                  //
-            [](Mammal*) { return 1.0; },  //
-            [](Amphibian*) { return 2.0f; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 0.0);
-    }
-}
-
-TEST(Castable, SwitchInferCastableReturnTypeWithDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto* result = Switch(
-            frog.get(),                          //
-            [](Mammal* p) { return p; },         //
-            [](Amphibian*) { return nullptr; },  //
-            [](Default) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), Mammal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-    {
-        auto* result = Switch(
-            bear.get(),                   //
-            [](Mammal* p) { return p; },  //
-            [](Amphibian* p) { return const_cast<const Amphibian*>(p); },
-            [](Default) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), const Animal*>);
-        EXPECT_EQ(result, bear.get());
-    }
-    {
-        auto* result = Switch(
-            gecko.get(),                     //
-            [](Mammal* p) { return p; },     //
-            [](Amphibian* p) { return p; },  //
-            [](Default) -> CastableBase* { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), CastableBase*>);
-        EXPECT_EQ(result, nullptr);
-    }
-}
-
-TEST(Castable, SwitchInferCastableReturnTypeWithoutDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto* result = Switch(
-            frog.get(),                   //
-            [](Mammal* p) { return p; },  //
-            [](Amphibian*) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), Mammal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-    {
-        auto* result = Switch(
-            bear.get(),                                                     //
-            [](Mammal* p) { return p; },                                    //
-            [](Amphibian* p) { return const_cast<const Amphibian*>(p); });  //
-        static_assert(std::is_same_v<decltype(result), const Animal*>);
-        EXPECT_EQ(result, bear.get());
-    }
-    {
-        auto* result = Switch(
-            gecko.get(),                  //
-            [](Mammal* p) { return p; },  //
-            [](Amphibian* p) { return p; });
-        static_assert(std::is_same_v<decltype(result), Animal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-}
-
-TEST(Castable, SwitchExplicitPODReturnTypeWithDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto result = Switch<double>(
-            frog.get(),                       //
-            [](Mammal*) { return 1; },        //
-            [](Amphibian*) { return 2.0f; },  //
-            [](Default) { return 3.0; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 2.0f);
-    }
-    {
-        auto result = Switch<double>(
-            bear.get(),                    //
-            [](Mammal*) { return 1; },     //
-            [](Amphibian*) { return 2; },  //
-            [](Default) { return 3; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 1.0f);
-    }
-    {
-        auto result = Switch<double>(
-            gecko.get(),                      //
-            [](Mammal*) { return 1.0f; },     //
-            [](Amphibian*) { return 2.0f; },  //
-            [](Default) { return 3.0f; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 3.0f);
-    }
-}
-
-TEST(Castable, SwitchExplicitPODReturnTypeWithoutDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto result = Switch<double>(
-            frog.get(),                 //
-            [](Mammal*) { return 1; },  //
-            [](Amphibian*) { return 2.0f; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 2.0f);
-    }
-    {
-        auto result = Switch<double>(
-            bear.get(),                    //
-            [](Mammal*) { return 1.0f; },  //
-            [](Amphibian*) { return 2; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 1.0f);
-    }
-    {
-        auto result = Switch<double>(
-            gecko.get(),                  //
-            [](Mammal*) { return 1.0; },  //
-            [](Amphibian*) { return 2.0f; });
-        static_assert(std::is_same_v<decltype(result), double>);
-        EXPECT_EQ(result, 0.0);
-    }
-}
-
-TEST(Castable, SwitchExplicitCastableReturnTypeWithDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto* result = Switch<Animal>(
-            frog.get(),                          //
-            [](Mammal* p) { return p; },         //
-            [](Amphibian*) { return nullptr; },  //
-            [](Default) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), Animal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-    {
-        auto* result = Switch<CastableBase>(
-            bear.get(),                   //
-            [](Mammal* p) { return p; },  //
-            [](Amphibian* p) { return const_cast<const Amphibian*>(p); },
-            [](Default) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), const CastableBase*>);
-        EXPECT_EQ(result, bear.get());
-    }
-    {
-        auto* result = Switch<const Animal>(
-            gecko.get(),                     //
-            [](Mammal* p) { return p; },     //
-            [](Amphibian* p) { return p; },  //
-            [](Default) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), const Animal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-}
-
-TEST(Castable, SwitchExplicitCastableReturnTypeWithoutDefault) {
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
-    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
-    {
-        auto* result = Switch<Animal>(
-            frog.get(),                   //
-            [](Mammal* p) { return p; },  //
-            [](Amphibian*) { return nullptr; });
-        static_assert(std::is_same_v<decltype(result), Animal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-    {
-        auto* result = Switch<CastableBase>(
-            bear.get(),                                                     //
-            [](Mammal* p) { return p; },                                    //
-            [](Amphibian* p) { return const_cast<const Amphibian*>(p); });  //
-        static_assert(std::is_same_v<decltype(result), const CastableBase*>);
-        EXPECT_EQ(result, bear.get());
-    }
-    {
-        auto* result = Switch<const Animal*>(
-            gecko.get(),                  //
-            [](Mammal* p) { return p; },  //
-            [](Amphibian* p) { return p; });
-        static_assert(std::is_same_v<decltype(result), const Animal*>);
-        EXPECT_EQ(result, nullptr);
-    }
-}
-
-TEST(Castable, SwitchNull) {
-    Animal* null = nullptr;
-    Switch(
-        null,  //
-        [&](Amphibian*) { FAIL() << "should not be called"; },
-        [&](Animal*) { FAIL() << "should not be called"; });
-}
-
-TEST(Castable, SwitchNullNoDefault) {
-    Animal* null = nullptr;
-    bool default_called = false;
-    Switch(
-        null,  //
-        [&](Amphibian*) { FAIL() << "should not be called"; },
-        [&](Animal*) { FAIL() << "should not be called"; },
-        [&](Default) { default_called = true; });
-    EXPECT_TRUE(default_called);
-}
-
-TEST(Castable, SwitchReturnNoDefaultInitializer) {
-    struct Object {
-        explicit Object(int v) : value(v) {}
-        int value;
-    };
-
-    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
-    {
-        auto result = Switch(
-            frog.get(),                            //
-            [](Mammal*) { return Object(1); },     //
-            [](Amphibian*) { return Object(2); },  //
-            [](Default) { return Object(3); });
-        static_assert(std::is_same_v<decltype(result), Object>);
-        EXPECT_EQ(result.value, 2);
-    }
-    {
-        auto result = Switch(
-            frog.get(),                         //
-            [](Mammal*) { return Object(1); },  //
-            [](Default) { return Object(3); });
-        static_assert(std::is_same_v<decltype(result), Object>);
-        EXPECT_EQ(result.value, 3);
-    }
-}
-
 // IsCastable static tests
 static_assert(IsCastable<CastableBase>);
 static_assert(IsCastable<Animal>);
diff --git a/src/tint/constant/value.cc b/src/tint/constant/value.cc
index c41ca34..7545731 100644
--- a/src/tint/constant/value.cc
+++ b/src/tint/constant/value.cc
@@ -14,6 +14,7 @@
 
 #include "src/tint/constant/value.h"
 
+#include "src/tint/switch.h"
 #include "src/tint/type/array.h"
 #include "src/tint/type/matrix.h"
 #include "src/tint/type/struct.h"
diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc
index 6e96523..a480356 100644
--- a/src/tint/inspector/inspector.cc
+++ b/src/tint/inspector/inspector.cc
@@ -39,6 +39,7 @@
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/struct.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/array.h"
 #include "src/tint/type/bool.h"
 #include "src/tint/type/depth_multisampled_texture.h"
diff --git a/src/tint/ir/builder_impl.cc b/src/tint/ir/builder_impl.cc
index 07a1fdb..02a7291 100644
--- a/src/tint/ir/builder_impl.cc
+++ b/src/tint/ir/builder_impl.cc
@@ -62,6 +62,7 @@
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
 #include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/void.h"
 
 namespace tint::ir {
diff --git a/src/tint/ir/constant.cc b/src/tint/ir/constant.cc
index da4cb36..61610b0 100644
--- a/src/tint/ir/constant.cc
+++ b/src/tint/ir/constant.cc
@@ -19,6 +19,7 @@
 #include "src/tint/constant/composite.h"
 #include "src/tint/constant/scalar.h"
 #include "src/tint/constant/splat.h"
+#include "src/tint/switch.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::ir::Constant);
 
diff --git a/src/tint/ir/debug.cc b/src/tint/ir/debug.cc
index 290aaeb..eb7d40f 100644
--- a/src/tint/ir/debug.cc
+++ b/src/tint/ir/debug.cc
@@ -22,6 +22,7 @@
 #include "src/tint/ir/loop.h"
 #include "src/tint/ir/switch.h"
 #include "src/tint/ir/terminator.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/string_stream.h"
 
 namespace tint::ir {
diff --git a/src/tint/ir/disassembler.cc b/src/tint/ir/disassembler.cc
index 1bac411..041495d 100644
--- a/src/tint/ir/disassembler.cc
+++ b/src/tint/ir/disassembler.cc
@@ -19,6 +19,7 @@
 #include "src/tint/ir/loop.h"
 #include "src/tint/ir/switch.h"
 #include "src/tint/ir/terminator.h"
+#include "src/tint/switch.h"
 
 namespace tint::ir {
 namespace {
diff --git a/src/tint/program.cc b/src/tint/program.cc
index 432758b..643999d 100644
--- a/src/tint/program.cc
+++ b/src/tint/program.cc
@@ -19,6 +19,7 @@
 #include "src/tint/resolver/resolver.h"
 #include "src/tint/sem/type_expression.h"
 #include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
 
 namespace tint {
 namespace {
diff --git a/src/tint/program_builder.cc b/src/tint/program_builder.cc
index c0c2ef6..99fce9e 100644
--- a/src/tint/program_builder.cc
+++ b/src/tint/program_builder.cc
@@ -21,6 +21,7 @@
 #include "src/tint/sem/type_expression.h"
 #include "src/tint/sem/value_expression.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/compiler_macros.h"
 
 using namespace tint::number_suffixes;  // NOLINT
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index 873ffe6..9318277 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -33,6 +33,7 @@
 #include "src/tint/ast/variable_decl_statement.h"
 #include "src/tint/builtin/builtin_value.h"
 #include "src/tint/builtin/function.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/spirv_atomic.h"
 #include "src/tint/type/depth_texture.h"
 #include "src/tint/type/sampled_texture.h"
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index a1c9042..9f5afdb 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -26,6 +26,7 @@
 #include "src/tint/ast/interpolate_attribute.h"
 #include "src/tint/ast/unary_op_expression.h"
 #include "src/tint/reader/spirv/function.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/depth_texture.h"
 #include "src/tint/type/multisampled_texture.h"
 #include "src/tint/type/sampled_texture.h"
diff --git a/src/tint/reader/spirv/parser_impl_test_helper.cc b/src/tint/reader/spirv/parser_impl_test_helper.cc
index 754922c..2a19c3d 100644
--- a/src/tint/reader/spirv/parser_impl_test_helper.cc
+++ b/src/tint/reader/spirv/parser_impl_test_helper.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "src/tint/reader/spirv/parser_impl_test_helper.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/string_stream.h"
 #include "src/tint/writer/wgsl/generator_impl.h"
 
diff --git a/src/tint/reader/spirv/parser_type.cc b/src/tint/reader/spirv/parser_type.cc
index 71482ff..bc8bd93 100644
--- a/src/tint/reader/spirv/parser_type.cc
+++ b/src/tint/reader/spirv/parser_type.cc
@@ -19,6 +19,7 @@
 #include <utility>
 
 #include "src/tint/program_builder.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/texture_dimension.h"
 #include "src/tint/utils/hash.h"
 #include "src/tint/utils/map.h"
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 325ccb0..0ce4f3b 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -30,6 +30,7 @@
 #include "src/tint/program_builder.h"
 #include "src/tint/sem/member_accessor_expression.h"
 #include "src/tint/sem/value_constructor.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/abstract_float.h"
 #include "src/tint/type/abstract_int.h"
 #include "src/tint/type/array.h"
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index de25b20..8cbea04 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -23,6 +23,7 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/test_helper.h"
 #include "src/tint/utils/string_stream.h"
 
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index 67eb430..47e691c 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -60,6 +60,7 @@
 #include "src/tint/builtin/builtin_value.h"
 #include "src/tint/scope_stack.h"
 #include "src/tint/sem/builtin.h"
+#include "src/tint/switch.h"
 #include "src/tint/symbol_table.h"
 #include "src/tint/utils/block_allocator.h"
 #include "src/tint/utils/compiler_macros.h"
diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc
index 8bd12e2..18e2f81 100644
--- a/src/tint/resolver/intrinsic_table.cc
+++ b/src/tint/resolver/intrinsic_table.cc
@@ -24,6 +24,7 @@
 #include "src/tint/sem/pipeline_stage_set.h"
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/abstract_float.h"
 #include "src/tint/type/abstract_int.h"
 #include "src/tint/type/abstract_numeric.h"
diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc
index 115c4b8..6fa2010 100644
--- a/src/tint/resolver/materialize_test.cc
+++ b/src/tint/resolver/materialize_test.cc
@@ -16,6 +16,7 @@
 
 #include "src/tint/resolver/resolver.h"
 #include "src/tint/resolver/resolver_test_helper.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/test_helper.h"
 
 #include "gmock/gmock.h"
diff --git a/src/tint/resolver/sem_helper.cc b/src/tint/resolver/sem_helper.cc
index 45bdf41..36cc4c7 100644
--- a/src/tint/resolver/sem_helper.cc
+++ b/src/tint/resolver/sem_helper.cc
@@ -19,6 +19,7 @@
 #include "src/tint/sem/function_expression.h"
 #include "src/tint/sem/type_expression.h"
 #include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
 
 namespace tint::resolver {
 
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index 4b63523..cf53c4c 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -37,6 +37,7 @@
 #include "src/tint/sem/value_conversion.h"
 #include "src/tint/sem/variable.h"
 #include "src/tint/sem/while_statement.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/block_allocator.h"
 #include "src/tint/utils/map.h"
 #include "src/tint/utils/string_stream.h"
diff --git a/src/tint/sem/info.cc b/src/tint/sem/info.cc
index e124c1c..52a04b5 100644
--- a/src/tint/sem/info.cc
+++ b/src/tint/sem/info.cc
@@ -18,6 +18,7 @@
 #include "src/tint/sem/module.h"
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
 
 namespace tint::sem {
 
diff --git a/src/tint/sem/value_expression.cc b/src/tint/sem/value_expression.cc
index 9fe4615..9aacfa3 100644
--- a/src/tint/sem/value_expression.cc
+++ b/src/tint/sem/value_expression.cc
@@ -18,6 +18,7 @@
 
 #include "src/tint/sem/load.h"
 #include "src/tint/sem/materialize.h"
+#include "src/tint/switch.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::sem::ValueExpression);
 
diff --git a/src/tint/switch.h b/src/tint/switch.h
new file mode 100644
index 0000000..9ae8d3b
--- /dev/null
+++ b/src/tint/switch.h
@@ -0,0 +1,305 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_SWITCH_H_
+#define SRC_TINT_SWITCH_H_
+
+#include <tuple>
+#include <utility>
+
+#include "src/tint/castable.h"
+
+namespace tint {
+
+/// Default can be used as the default case for a Switch(), when all previous cases failed to match.
+///
+/// Example:
+/// ```
+/// Switch(object,
+///     [&](TypeA*) { /* ... */ },
+///     [&](TypeB*) { /* ... */ },
+///     [&](Default) { /* If not TypeA or TypeB */ });
+/// ```
+struct Default {};
+
+}  // namespace tint
+
+namespace tint::detail {
+
+/// Evaluates to the Switch case type being matched by the switch case function `FN`.
+/// @note does not handle the Default case
+/// @see Switch().
+template <typename FN>
+using SwitchCaseType = std::remove_pointer_t<traits::ParameterType<std::remove_reference_t<FN>, 0>>;
+
+/// Evaluates to true if the function `FN` has the signature of a Default case in a Switch().
+/// @see Switch().
+template <typename FN>
+inline constexpr bool IsDefaultCase =
+    std::is_same_v<traits::ParameterType<std::remove_reference_t<FN>, 0>, Default>;
+
+/// Searches the list of Switch cases for a Default case, returning the index of the Default case.
+/// If the a Default case is not found in the tuple, then -1 is returned.
+template <typename TUPLE, std::size_t START_IDX = 0>
+constexpr int IndexOfDefaultCase() {
+    if constexpr (START_IDX < std::tuple_size_v<TUPLE>) {
+        return IsDefaultCase<std::tuple_element_t<START_IDX, TUPLE>>
+                   ? static_cast<int>(START_IDX)
+                   : IndexOfDefaultCase<TUPLE, START_IDX + 1>();
+    } else {
+        return -1;
+    }
+}
+
+/// The implementation of Switch() for non-Default cases.
+/// Switch splits the cases into two a low and high block of cases, and quickly rules out blocks
+/// that cannot match by comparing the HashCode of the object and the cases in the block. If a block
+/// of cases may match the given object's type, then that block is split into two, and the process
+/// recurses. When NonDefaultCases() is called with a single case, then As<> will be used to
+/// dynamically cast to the case type and if the cast succeeds, then the case handler is called.
+/// @returns true if a case handler was found, otherwise false.
+template <typename T, typename RETURN_TYPE, typename... CASES>
+inline bool NonDefaultCases([[maybe_unused]] T* object,
+                            const TypeInfo* type,
+                            [[maybe_unused]] RETURN_TYPE* result,
+                            std::tuple<CASES...>&& cases) {
+    using Cases = std::tuple<CASES...>;
+
+    static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
+    static constexpr size_t kNumCases = sizeof...(CASES);
+
+    if constexpr (kNumCases == 0) {
+        // No cases. Nothing to do.
+        return false;
+    } else if constexpr (kNumCases == 1) {  // NOLINT: cpplint doesn't understand
+                                            // `else if constexpr`
+        // Single case.
+        using CaseFunc = std::tuple_element_t<0, Cases>;
+        static_assert(!IsDefaultCase<CaseFunc>, "NonDefaultCases called with a Default case");
+        // Attempt to dynamically cast the object to the handler type. If that succeeds, call the
+        // case handler with the cast object.
+        using CaseType = SwitchCaseType<CaseFunc>;
+        if (type->Is<CaseType>()) {
+            auto* ptr = static_cast<CaseType*>(object);
+            if constexpr (kHasReturnType) {
+                new (result) RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<0>(cases)(ptr)));
+            } else {
+                std::get<0>(cases)(ptr);
+            }
+            return true;
+        }
+        return false;
+    } else {
+        // Multiple cases.
+        // Check the hashcode bits to see if there's any possibility of a case matching in these
+        // cases. If there isn't, we can skip all these cases.
+        if (MaybeAnyOf(TypeInfo::CombinedHashCodeOf<SwitchCaseType<CASES>...>(),
+                       type->full_hashcode)) {
+            // Split the cases into two, and recurse.
+            constexpr size_t kMid = kNumCases / 2;
+            return NonDefaultCases(object, type, result, traits::Slice<0, kMid>(cases)) ||
+                   NonDefaultCases(object, type, result,
+                                   traits::Slice<kMid, kNumCases - kMid>(cases));
+        } else {
+            return false;
+        }
+    }
+}
+
+/// The implementation of Switch() for all cases.
+/// @see NonDefaultCases
+template <typename T, typename RETURN_TYPE, typename... CASES>
+inline void SwitchCases(T* object, RETURN_TYPE* result, std::tuple<CASES...>&& cases) {
+    using Cases = std::tuple<CASES...>;
+
+    static constexpr int kDefaultIndex = detail::IndexOfDefaultCase<Cases>();
+    static constexpr bool kHasDefaultCase = kDefaultIndex >= 0;
+    static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
+
+    // Static assertions
+    static constexpr bool kDefaultIsOK =
+        kDefaultIndex == -1 || kDefaultIndex == static_cast<int>(std::tuple_size_v<Cases> - 1);
+    static constexpr bool kReturnIsOK =
+        kHasDefaultCase || !kHasReturnType || std::is_constructible_v<RETURN_TYPE>;
+    static_assert(kDefaultIsOK, "Default case must be last in Switch()");
+    static_assert(kReturnIsOK,
+                  "Switch() requires either a Default case or a return type that is either void or "
+                  "default-constructable");
+
+    // If the static asserts have fired, don't bother spewing more errors below
+    static constexpr bool kAllOK = kDefaultIsOK && kReturnIsOK;
+    if constexpr (kAllOK) {
+        if (object) {
+            auto* type = &object->TypeInfo();
+            if constexpr (kHasDefaultCase) {
+                // Evaluate non-default cases.
+                if (!detail::NonDefaultCases<T>(object, type, result,
+                                                traits::Slice<0, kDefaultIndex>(cases))) {
+                    // Nothing matched. Evaluate default case.
+                    if constexpr (kHasReturnType) {
+                        new (result) RETURN_TYPE(
+                            static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
+                    } else {
+                        std::get<kDefaultIndex>(cases)({});
+                    }
+                }
+            } else {
+                if (!detail::NonDefaultCases<T>(object, type, result, std::move(cases))) {
+                    // Nothing matched. No default case.
+                    if constexpr (kHasReturnType) {
+                        new (result) RETURN_TYPE();
+                    }
+                }
+            }
+        } else {
+            // Object is nullptr, so no cases can match
+            if constexpr (kHasDefaultCase) {
+                // Evaluate default case.
+                if constexpr (kHasReturnType) {
+                    new (result)
+                        RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
+                } else {
+                    std::get<kDefaultIndex>(cases)({});
+                }
+            } else {
+                // No default case, no case can match.
+                if constexpr (kHasReturnType) {
+                    new (result) RETURN_TYPE();
+                }
+            }
+        }
+    }
+}
+
+/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore.
+template <typename T>
+using NullptrToIgnore = std::conditional_t<std::is_same_v<T, std::nullptr_t>, Ignore, T>;
+
+/// Resolves to `const TYPE` if any of `CASE_RETURN_TYPES` are const or pointer-to-const, otherwise
+/// resolves to TYPE.
+template <typename TYPE, typename... CASE_RETURN_TYPES>
+using PropagateReturnConst = std::conditional_t<
+    // Are any of the pointer-stripped types const?
+    (std::is_const_v<std::remove_pointer_t<CASE_RETURN_TYPES>> || ...),
+    const TYPE,  // Yes: Apply const to TYPE
+    TYPE>;       // No:  Passthrough
+
+/// SwitchReturnTypeImpl is the implementation of SwitchReturnType
+template <bool IS_CASTABLE, typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl;
+
+/// SwitchReturnTypeImpl specialization for non-castable case types and an explicitly specified
+/// return type.
+template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ false, REQUESTED_TYPE, CASE_RETURN_TYPES...> {
+    /// Resolves to `REQUESTED_TYPE`
+    using type = REQUESTED_TYPE;
+};
+
+/// SwitchReturnTypeImpl specialization for non-castable case types and an inferred return type.
+template <typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ false, Infer, CASE_RETURN_TYPES...> {
+    /// Resolves to the common type for all the cases return types.
+    using type = std::common_type_t<CASE_RETURN_TYPES...>;
+};
+
+/// SwitchReturnTypeImpl specialization for castable case types and an explicitly specified return
+/// type.
+template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ true, REQUESTED_TYPE, CASE_RETURN_TYPES...> {
+  public:
+    /// Resolves to `const REQUESTED_TYPE*` or `REQUESTED_TYPE*`
+    using type = PropagateReturnConst<std::remove_pointer_t<REQUESTED_TYPE>, CASE_RETURN_TYPES...>*;
+};
+
+/// SwitchReturnTypeImpl specialization for castable case types and an inferred return type.
+template <typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ true, Infer, CASE_RETURN_TYPES...> {
+  private:
+    using InferredType =
+        CastableCommonBase<detail::NullptrToIgnore<std::remove_pointer_t<CASE_RETURN_TYPES>>...>;
+
+  public:
+    /// `const T*` or `T*`, where T is the common base type for all the castable case types.
+    using type = PropagateReturnConst<InferredType, CASE_RETURN_TYPES...>*;
+};
+
+/// Resolves to the return type for a Switch() with the requested return type `REQUESTED_TYPE` and
+/// case statement return types. If `REQUESTED_TYPE` is Infer then the return type will be inferred
+/// from the case return types.
+template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+using SwitchReturnType = typename SwitchReturnTypeImpl<
+    IsCastable<NullptrToIgnore<std::remove_pointer_t<CASE_RETURN_TYPES>>...>,
+    REQUESTED_TYPE,
+    CASE_RETURN_TYPES...>::type;
+
+}  // namespace tint::detail
+
+namespace tint {
+
+/// Switch is used to dispatch one of the provided callback case handler functions based on the type
+/// of `object` and the parameter type of the case handlers. Switch will sequentially check the type
+/// of `object` against each of the switch case handler functions, and will invoke the first case
+/// handler function which has a parameter type that matches the object type. When a case handler is
+/// matched, it will be called with the single argument of `object` cast to the case handler's
+/// parameter type. Switch will invoke at most one case handler. Each of the case functions must
+/// have the signature `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
+/// is the return type, consistent across all case handlers.
+///
+/// An optional default case function with the signature `R(Default)` can be used as the last case.
+/// This default case will be called if all previous cases failed to match.
+///
+/// If `object` is nullptr and a default case is provided, then the default case will be called. If
+/// `object` is nullptr and no default case is provided, then no cases will be called.
+///
+/// Example:
+/// ```
+/// Switch(object,
+///     [&](TypeA*) { /* ... */ },
+///     [&](TypeB*) { /* ... */ });
+///
+/// Switch(object,
+///     [&](TypeA*) { /* ... */ },
+///     [&](TypeB*) { /* ... */ },
+///     [&](Default) { /* Called if object is not TypeA or TypeB */ });
+/// ```
+///
+/// @param object the object who's type is used to
+/// @param cases the switch cases
+/// @return the value returned by the called case. If no cases matched, then the zero value for the
+/// consistent case type.
+template <typename RETURN_TYPE = detail::Infer, typename T = CastableBase, typename... CASES>
+inline auto Switch(T* object, CASES&&... cases) {
+    using ReturnType = detail::SwitchReturnType<RETURN_TYPE, traits::ReturnType<CASES>...>;
+    static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
+
+    if constexpr (kHasReturnType) {
+        // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC.
+        struct alignas(alignof(ReturnType)) ReturnStorage {
+            uint8_t data[sizeof(ReturnType)];
+        };
+        ReturnStorage storage;
+        auto* res = utils::Bitcast<ReturnType*>(&storage);
+        TINT_DEFER(res->~ReturnType());
+        detail::SwitchCases(object, res, std::forward_as_tuple(std::forward<CASES>(cases)...));
+        return *res;
+    } else {
+        detail::SwitchCases<T, void>(object, nullptr,
+                                     std::forward_as_tuple(std::forward<CASES>(cases)...));
+    }
+}
+
+}  // namespace tint
+
+#endif  // SRC_TINT_SWITCH_H_
diff --git a/src/tint/castable_bench.cc b/src/tint/switch_bench.cc
similarity index 99%
rename from src/tint/castable_bench.cc
rename to src/tint/switch_bench.cc
index c9d0c43..ea6098b 100644
--- a/src/tint/castable_bench.cc
+++ b/src/tint/switch_bench.cc
@@ -16,7 +16,7 @@
 
 #include "benchmark/benchmark.h"
 
-#include "src/tint/castable.h"
+#include "src/tint/switch.h"
 
 namespace tint {
 namespace {
diff --git a/src/tint/switch_test.cc b/src/tint/switch_test.cc
new file mode 100644
index 0000000..d93f830
--- /dev/null
+++ b/src/tint/switch_test.cc
@@ -0,0 +1,552 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/switch.h"
+
+#include <memory>
+#include <string>
+
+#include "gtest/gtest.h"
+
+namespace tint {
+namespace {
+
+struct Animal : public tint::Castable<Animal> {};
+struct Amphibian : public tint::Castable<Amphibian, Animal> {};
+struct Mammal : public tint::Castable<Mammal, Animal> {};
+struct Reptile : public tint::Castable<Reptile, Animal> {};
+struct Frog : public tint::Castable<Frog, Amphibian> {};
+struct Bear : public tint::Castable<Bear, Mammal> {};
+struct Lizard : public tint::Castable<Lizard, Reptile> {};
+struct Gecko : public tint::Castable<Gecko, Lizard> {};
+struct Iguana : public tint::Castable<Iguana, Lizard> {};
+
+TEST(Castable, SwitchNoDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        bool frog_matched_amphibian = false;
+        Switch(
+            frog.get(),  //
+            [&](Reptile*) { FAIL() << "frog is not reptile"; },
+            [&](Mammal*) { FAIL() << "frog is not mammal"; },
+            [&](Amphibian* amphibian) {
+                EXPECT_EQ(amphibian, frog.get());
+                frog_matched_amphibian = true;
+            });
+        EXPECT_TRUE(frog_matched_amphibian);
+    }
+    {
+        bool bear_matched_mammal = false;
+        Switch(
+            bear.get(),  //
+            [&](Reptile*) { FAIL() << "bear is not reptile"; },
+            [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+            [&](Mammal* mammal) {
+                EXPECT_EQ(mammal, bear.get());
+                bear_matched_mammal = true;
+            });
+        EXPECT_TRUE(bear_matched_mammal);
+    }
+    {
+        bool gecko_matched_reptile = false;
+        Switch(
+            gecko.get(),  //
+            [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+            [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+            [&](Reptile* reptile) {
+                EXPECT_EQ(reptile, gecko.get());
+                gecko_matched_reptile = true;
+            });
+        EXPECT_TRUE(gecko_matched_reptile);
+    }
+}
+
+TEST(Castable, SwitchWithUnusedDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        bool frog_matched_amphibian = false;
+        Switch(
+            frog.get(),  //
+            [&](Reptile*) { FAIL() << "frog is not reptile"; },
+            [&](Mammal*) { FAIL() << "frog is not mammal"; },
+            [&](Amphibian* amphibian) {
+                EXPECT_EQ(amphibian, frog.get());
+                frog_matched_amphibian = true;
+            },
+            [&](Default) { FAIL() << "default should not have been selected"; });
+        EXPECT_TRUE(frog_matched_amphibian);
+    }
+    {
+        bool bear_matched_mammal = false;
+        Switch(
+            bear.get(),  //
+            [&](Reptile*) { FAIL() << "bear is not reptile"; },
+            [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+            [&](Mammal* mammal) {
+                EXPECT_EQ(mammal, bear.get());
+                bear_matched_mammal = true;
+            },
+            [&](Default) { FAIL() << "default should not have been selected"; });
+        EXPECT_TRUE(bear_matched_mammal);
+    }
+    {
+        bool gecko_matched_reptile = false;
+        Switch(
+            gecko.get(),  //
+            [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+            [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+            [&](Reptile* reptile) {
+                EXPECT_EQ(reptile, gecko.get());
+                gecko_matched_reptile = true;
+            },
+            [&](Default) { FAIL() << "default should not have been selected"; });
+        EXPECT_TRUE(gecko_matched_reptile);
+    }
+}
+
+TEST(Castable, SwitchDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        bool frog_matched_default = false;
+        Switch(
+            frog.get(),  //
+            [&](Reptile*) { FAIL() << "frog is not reptile"; },
+            [&](Mammal*) { FAIL() << "frog is not mammal"; },
+            [&](Default) { frog_matched_default = true; });
+        EXPECT_TRUE(frog_matched_default);
+    }
+    {
+        bool bear_matched_default = false;
+        Switch(
+            bear.get(),  //
+            [&](Reptile*) { FAIL() << "bear is not reptile"; },
+            [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+            [&](Default) { bear_matched_default = true; });
+        EXPECT_TRUE(bear_matched_default);
+    }
+    {
+        bool gecko_matched_default = false;
+        Switch(
+            gecko.get(),  //
+            [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+            [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+            [&](Default) { gecko_matched_default = true; });
+        EXPECT_TRUE(gecko_matched_default);
+    }
+}
+
+TEST(Castable, SwitchMatchFirst) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    {
+        bool frog_matched_animal = false;
+        Switch(
+            frog.get(),
+            [&](Animal* animal) {
+                EXPECT_EQ(animal, frog.get());
+                frog_matched_animal = true;
+            },
+            [&](Amphibian*) { FAIL() << "animal should have been matched first"; });
+        EXPECT_TRUE(frog_matched_animal);
+    }
+    {
+        bool frog_matched_amphibian = false;
+        Switch(
+            frog.get(),
+            [&](Amphibian* amphibain) {
+                EXPECT_EQ(amphibain, frog.get());
+                frog_matched_amphibian = true;
+            },
+            [&](Animal*) { FAIL() << "amphibian should have been matched first"; });
+        EXPECT_TRUE(frog_matched_amphibian);
+    }
+}
+
+TEST(Castable, SwitchReturnValueWithDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        const char* result = Switch(
+            frog.get(),                              //
+            [](Mammal*) { return "mammal"; },        //
+            [](Amphibian*) { return "amphibian"; },  //
+            [](Default) { return "unknown"; });
+        static_assert(std::is_same_v<decltype(result), const char*>);
+        EXPECT_EQ(std::string(result), "amphibian");
+    }
+    {
+        const char* result = Switch(
+            bear.get(),                              //
+            [](Mammal*) { return "mammal"; },        //
+            [](Amphibian*) { return "amphibian"; },  //
+            [](Default) { return "unknown"; });
+        static_assert(std::is_same_v<decltype(result), const char*>);
+        EXPECT_EQ(std::string(result), "mammal");
+    }
+    {
+        const char* result = Switch(
+            gecko.get(),                             //
+            [](Mammal*) { return "mammal"; },        //
+            [](Amphibian*) { return "amphibian"; },  //
+            [](Default) { return "unknown"; });
+        static_assert(std::is_same_v<decltype(result), const char*>);
+        EXPECT_EQ(std::string(result), "unknown");
+    }
+}
+
+TEST(Castable, SwitchReturnValueWithoutDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        const char* result = Switch(
+            frog.get(),                        //
+            [](Mammal*) { return "mammal"; },  //
+            [](Amphibian*) { return "amphibian"; });
+        static_assert(std::is_same_v<decltype(result), const char*>);
+        EXPECT_EQ(std::string(result), "amphibian");
+    }
+    {
+        const char* result = Switch(
+            bear.get(),                        //
+            [](Mammal*) { return "mammal"; },  //
+            [](Amphibian*) { return "amphibian"; });
+        static_assert(std::is_same_v<decltype(result), const char*>);
+        EXPECT_EQ(std::string(result), "mammal");
+    }
+    {
+        auto* result = Switch(
+            gecko.get(),                       //
+            [](Mammal*) { return "mammal"; },  //
+            [](Amphibian*) { return "amphibian"; });
+        static_assert(std::is_same_v<decltype(result), const char*>);
+        EXPECT_EQ(result, nullptr);
+    }
+}
+
+TEST(Castable, SwitchInferPODReturnTypeWithDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto result = Switch(
+            frog.get(),                       //
+            [](Mammal*) { return 1; },        //
+            [](Amphibian*) { return 2.0f; },  //
+            [](Default) { return 3.0; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 2.0);
+    }
+    {
+        auto result = Switch(
+            bear.get(),                       //
+            [](Mammal*) { return 1.0; },      //
+            [](Amphibian*) { return 2.0f; },  //
+            [](Default) { return 3; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 1.0);
+    }
+    {
+        auto result = Switch(
+            gecko.get(),                   //
+            [](Mammal*) { return 1.0f; },  //
+            [](Amphibian*) { return 2; },  //
+            [](Default) { return 3.0; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 3.0);
+    }
+}
+
+TEST(Castable, SwitchInferPODReturnTypeWithoutDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto result = Switch(
+            frog.get(),                 //
+            [](Mammal*) { return 1; },  //
+            [](Amphibian*) { return 2.0f; });
+        static_assert(std::is_same_v<decltype(result), float>);
+        EXPECT_EQ(result, 2.0f);
+    }
+    {
+        auto result = Switch(
+            bear.get(),                    //
+            [](Mammal*) { return 1.0f; },  //
+            [](Amphibian*) { return 2; });
+        static_assert(std::is_same_v<decltype(result), float>);
+        EXPECT_EQ(result, 1.0f);
+    }
+    {
+        auto result = Switch(
+            gecko.get(),                  //
+            [](Mammal*) { return 1.0; },  //
+            [](Amphibian*) { return 2.0f; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 0.0);
+    }
+}
+
+TEST(Castable, SwitchInferCastableReturnTypeWithDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto* result = Switch(
+            frog.get(),                          //
+            [](Mammal* p) { return p; },         //
+            [](Amphibian*) { return nullptr; },  //
+            [](Default) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), Mammal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+    {
+        auto* result = Switch(
+            bear.get(),                   //
+            [](Mammal* p) { return p; },  //
+            [](Amphibian* p) { return const_cast<const Amphibian*>(p); },
+            [](Default) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), const Animal*>);
+        EXPECT_EQ(result, bear.get());
+    }
+    {
+        auto* result = Switch(
+            gecko.get(),                     //
+            [](Mammal* p) { return p; },     //
+            [](Amphibian* p) { return p; },  //
+            [](Default) -> CastableBase* { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), CastableBase*>);
+        EXPECT_EQ(result, nullptr);
+    }
+}
+
+TEST(Castable, SwitchInferCastableReturnTypeWithoutDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto* result = Switch(
+            frog.get(),                   //
+            [](Mammal* p) { return p; },  //
+            [](Amphibian*) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), Mammal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+    {
+        auto* result = Switch(
+            bear.get(),                                                     //
+            [](Mammal* p) { return p; },                                    //
+            [](Amphibian* p) { return const_cast<const Amphibian*>(p); });  //
+        static_assert(std::is_same_v<decltype(result), const Animal*>);
+        EXPECT_EQ(result, bear.get());
+    }
+    {
+        auto* result = Switch(
+            gecko.get(),                  //
+            [](Mammal* p) { return p; },  //
+            [](Amphibian* p) { return p; });
+        static_assert(std::is_same_v<decltype(result), Animal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+}
+
+TEST(Castable, SwitchExplicitPODReturnTypeWithDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto result = Switch<double>(
+            frog.get(),                       //
+            [](Mammal*) { return 1; },        //
+            [](Amphibian*) { return 2.0f; },  //
+            [](Default) { return 3.0; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 2.0f);
+    }
+    {
+        auto result = Switch<double>(
+            bear.get(),                    //
+            [](Mammal*) { return 1; },     //
+            [](Amphibian*) { return 2; },  //
+            [](Default) { return 3; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 1.0f);
+    }
+    {
+        auto result = Switch<double>(
+            gecko.get(),                      //
+            [](Mammal*) { return 1.0f; },     //
+            [](Amphibian*) { return 2.0f; },  //
+            [](Default) { return 3.0f; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 3.0f);
+    }
+}
+
+TEST(Castable, SwitchExplicitPODReturnTypeWithoutDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto result = Switch<double>(
+            frog.get(),                 //
+            [](Mammal*) { return 1; },  //
+            [](Amphibian*) { return 2.0f; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 2.0f);
+    }
+    {
+        auto result = Switch<double>(
+            bear.get(),                    //
+            [](Mammal*) { return 1.0f; },  //
+            [](Amphibian*) { return 2; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 1.0f);
+    }
+    {
+        auto result = Switch<double>(
+            gecko.get(),                  //
+            [](Mammal*) { return 1.0; },  //
+            [](Amphibian*) { return 2.0f; });
+        static_assert(std::is_same_v<decltype(result), double>);
+        EXPECT_EQ(result, 0.0);
+    }
+}
+
+TEST(Castable, SwitchExplicitCastableReturnTypeWithDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto* result = Switch<Animal>(
+            frog.get(),                          //
+            [](Mammal* p) { return p; },         //
+            [](Amphibian*) { return nullptr; },  //
+            [](Default) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), Animal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+    {
+        auto* result = Switch<CastableBase>(
+            bear.get(),                   //
+            [](Mammal* p) { return p; },  //
+            [](Amphibian* p) { return const_cast<const Amphibian*>(p); },
+            [](Default) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), const CastableBase*>);
+        EXPECT_EQ(result, bear.get());
+    }
+    {
+        auto* result = Switch<const Animal>(
+            gecko.get(),                     //
+            [](Mammal* p) { return p; },     //
+            [](Amphibian* p) { return p; },  //
+            [](Default) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), const Animal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+}
+
+TEST(Castable, SwitchExplicitCastableReturnTypeWithoutDefault) {
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+    std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+    {
+        auto* result = Switch<Animal>(
+            frog.get(),                   //
+            [](Mammal* p) { return p; },  //
+            [](Amphibian*) { return nullptr; });
+        static_assert(std::is_same_v<decltype(result), Animal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+    {
+        auto* result = Switch<CastableBase>(
+            bear.get(),                                                     //
+            [](Mammal* p) { return p; },                                    //
+            [](Amphibian* p) { return const_cast<const Amphibian*>(p); });  //
+        static_assert(std::is_same_v<decltype(result), const CastableBase*>);
+        EXPECT_EQ(result, bear.get());
+    }
+    {
+        auto* result = Switch<const Animal*>(
+            gecko.get(),                  //
+            [](Mammal* p) { return p; },  //
+            [](Amphibian* p) { return p; });
+        static_assert(std::is_same_v<decltype(result), const Animal*>);
+        EXPECT_EQ(result, nullptr);
+    }
+}
+
+TEST(Castable, SwitchNull) {
+    Animal* null = nullptr;
+    Switch(
+        null,  //
+        [&](Amphibian*) { FAIL() << "should not be called"; },
+        [&](Animal*) { FAIL() << "should not be called"; });
+}
+
+TEST(Castable, SwitchNullNoDefault) {
+    Animal* null = nullptr;
+    bool default_called = false;
+    Switch(
+        null,  //
+        [&](Amphibian*) { FAIL() << "should not be called"; },
+        [&](Animal*) { FAIL() << "should not be called"; },
+        [&](Default) { default_called = true; });
+    EXPECT_TRUE(default_called);
+}
+
+TEST(Castable, SwitchReturnNoDefaultInitializer) {
+    struct Object {
+        explicit Object(int v) : value(v) {}
+        int value;
+    };
+
+    std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+    {
+        auto result = Switch(
+            frog.get(),                            //
+            [](Mammal*) { return Object(1); },     //
+            [](Amphibian*) { return Object(2); },  //
+            [](Default) { return Object(3); });
+        static_assert(std::is_same_v<decltype(result), Object>);
+        EXPECT_EQ(result.value, 2);
+    }
+    {
+        auto result = Switch(
+            frog.get(),                         //
+            [](Mammal*) { return Object(1); },  //
+            [](Default) { return Object(3); });
+        static_assert(std::is_same_v<decltype(result), Object>);
+        EXPECT_EQ(result.value, 3);
+    }
+}
+
+}  // namespace
+
+TINT_INSTANTIATE_TYPEINFO(Animal);
+TINT_INSTANTIATE_TYPEINFO(Amphibian);
+TINT_INSTANTIATE_TYPEINFO(Mammal);
+TINT_INSTANTIATE_TYPEINFO(Reptile);
+TINT_INSTANTIATE_TYPEINFO(Frog);
+TINT_INSTANTIATE_TYPEINFO(Bear);
+TINT_INSTANTIATE_TYPEINFO(Lizard);
+TINT_INSTANTIATE_TYPEINFO(Gecko);
+
+}  // namespace tint
diff --git a/src/tint/transform/builtin_polyfill.cc b/src/tint/transform/builtin_polyfill.cc
index f66025c..ba7c5d2 100644
--- a/src/tint/transform/builtin_polyfill.cc
+++ b/src/tint/transform/builtin_polyfill.cc
@@ -23,6 +23,7 @@
 #include "src/tint/sem/builtin.h"
 #include "src/tint/sem/call.h"
 #include "src/tint/sem/type_expression.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/storage_texture.h"
 #include "src/tint/type/texture_dimension.h"
 #include "src/tint/utils/map.h"
diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc
index 36ad192..743a992 100644
--- a/src/tint/transform/calculate_array_length.cc
+++ b/src/tint/transform/calculate_array_length.cc
@@ -26,6 +26,7 @@
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/struct.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/simplify_pointers.h"
 #include "src/tint/type/reference.h"
 #include "src/tint/utils/hash.h"
diff --git a/src/tint/transform/decompose_memory_access.cc b/src/tint/transform/decompose_memory_access.cc
index 2ff62b4..f2146f3 100644
--- a/src/tint/transform/decompose_memory_access.cc
+++ b/src/tint/transform/decompose_memory_access.cc
@@ -30,6 +30,7 @@
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/struct.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/array.h"
 #include "src/tint/type/atomic.h"
 #include "src/tint/type/reference.h"
diff --git a/src/tint/transform/demote_to_helper.cc b/src/tint/transform/demote_to_helper.cc
index 1cb6f96..c853ef0 100644
--- a/src/tint/transform/demote_to_helper.cc
+++ b/src/tint/transform/demote_to_helper.cc
@@ -23,6 +23,7 @@
 #include "src/tint/sem/call.h"
 #include "src/tint/sem/function.h"
 #include "src/tint/sem/statement.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/utils/hoist_to_decl_before.h"
 #include "src/tint/type/reference.h"
 #include "src/tint/utils/map.h"
diff --git a/src/tint/transform/merge_return.cc b/src/tint/transform/merge_return.cc
index 353fb13..f9f2feb 100644
--- a/src/tint/transform/merge_return.cc
+++ b/src/tint/transform/merge_return.cc
@@ -18,6 +18,7 @@
 
 #include "src/tint/program_builder.h"
 #include "src/tint/sem/statement.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/scoped_assignment.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::MergeReturn);
diff --git a/src/tint/transform/packed_vec3.cc b/src/tint/transform/packed_vec3.cc
index 956adcb..8a23221 100644
--- a/src/tint/transform/packed_vec3.cc
+++ b/src/tint/transform/packed_vec3.cc
@@ -27,6 +27,7 @@
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/type_expression.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/array.h"
 #include "src/tint/type/reference.h"
 #include "src/tint/type/vector.h"
diff --git a/src/tint/transform/preserve_padding.cc b/src/tint/transform/preserve_padding.cc
index fbe785c..59fd786 100644
--- a/src/tint/transform/preserve_padding.cc
+++ b/src/tint/transform/preserve_padding.cc
@@ -19,6 +19,7 @@
 
 #include "src/tint/program_builder.h"
 #include "src/tint/sem/struct.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/reference.h"
 #include "src/tint/utils/map.h"
 #include "src/tint/utils/vector.h"
diff --git a/src/tint/transform/renamer.cc b/src/tint/transform/renamer.cc
index 3ddb6c6..708c6ec 100644
--- a/src/tint/transform/renamer.cc
+++ b/src/tint/transform/renamer.cc
@@ -24,6 +24,7 @@
 #include "src/tint/sem/type_expression.h"
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
+#include "src/tint/switch.h"
 #include "src/tint/text/unicode.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer);
diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc
index b7a46f5..2eb74a1 100644
--- a/src/tint/transform/robustness.cc
+++ b/src/tint/transform/robustness.cc
@@ -28,6 +28,7 @@
 #include "src/tint/sem/member_accessor_expression.h"
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/utils/hoist_to_decl_before.h"
 #include "src/tint/type/reference.h"
 
diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc
index e9c719e..bec9019 100644
--- a/src/tint/transform/simplify_pointers.cc
+++ b/src/tint/transform/simplify_pointers.cc
@@ -24,6 +24,7 @@
 #include "src/tint/sem/function.h"
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/unshadow.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::SimplifyPointers);
diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc
index 8a8daeb..8be90bb 100644
--- a/src/tint/transform/single_entry_point.cc
+++ b/src/tint/transform/single_entry_point.cc
@@ -20,6 +20,7 @@
 #include "src/tint/program_builder.h"
 #include "src/tint/sem/function.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint);
 TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config);
diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc
index 73c2ae2..f41fbb7 100644
--- a/src/tint/transform/spirv_atomic.cc
+++ b/src/tint/transform/spirv_atomic.cc
@@ -26,6 +26,7 @@
 #include "src/tint/sem/index_accessor_expression.h"
 #include "src/tint/sem/member_accessor_expression.h"
 #include "src/tint/sem/statement.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/reference.h"
 #include "src/tint/utils/map.h"
 #include "src/tint/utils/unique_vector.h"
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index 7a9cffc..bfd8c80 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -25,6 +25,7 @@
 #include "src/tint/sem/module.h"
 #include "src/tint/sem/struct.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/compiler_macros.h"
 #include "src/tint/utils/hashmap.h"
 #include "src/tint/utils/transform.h"
diff --git a/src/tint/transform/substitute_override.cc b/src/tint/transform/substitute_override.cc
index e4b2d7b..04a9da0 100644
--- a/src/tint/transform/substitute_override.cc
+++ b/src/tint/transform/substitute_override.cc
@@ -22,6 +22,7 @@
 #include "src/tint/sem/builtin.h"
 #include "src/tint/sem/index_accessor_expression.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride);
 TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config);
diff --git a/src/tint/transform/texture_1d_to_2d.cc b/src/tint/transform/texture_1d_to_2d.cc
index 5050aa0..7c48db5 100644
--- a/src/tint/transform/texture_1d_to_2d.cc
+++ b/src/tint/transform/texture_1d_to_2d.cc
@@ -20,6 +20,7 @@
 #include "src/tint/sem/function.h"
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/type_expression.h"
+#include "src/tint/switch.h"
 #include "src/tint/type/texture_dimension.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::Texture1DTo2D);
diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc
index 8f79544..9639cfc 100644
--- a/src/tint/transform/unshadow.cc
+++ b/src/tint/transform/unshadow.cc
@@ -23,6 +23,7 @@
 #include "src/tint/sem/function.h"
 #include "src/tint/sem/statement.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow);
 
diff --git a/src/tint/transform/utils/get_insertion_point.cc b/src/tint/transform/utils/get_insertion_point.cc
index d10d134..bce7a7e 100644
--- a/src/tint/transform/utils/get_insertion_point.cc
+++ b/src/tint/transform/utils/get_insertion_point.cc
@@ -16,6 +16,7 @@
 #include "src/tint/debug.h"
 #include "src/tint/diagnostic/diagnostic.h"
 #include "src/tint/sem/for_loop_statement.h"
+#include "src/tint/switch.h"
 
 namespace tint::transform::utils {
 
diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc
index e07dcfe..d887198 100644
--- a/src/tint/transform/vertex_pulling.cc
+++ b/src/tint/transform/vertex_pulling.cc
@@ -23,6 +23,7 @@
 #include "src/tint/builtin/builtin_value.h"
 #include "src/tint/program_builder.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/compiler_macros.h"
 #include "src/tint/utils/map.h"
 #include "src/tint/utils/math.h"
diff --git a/src/tint/type/type.cc b/src/tint/type/type.cc
index 6163a9a..7ed62fa 100644
--- a/src/tint/type/type.cc
+++ b/src/tint/type/type.cc
@@ -14,6 +14,7 @@
 
 #include "src/tint/type/type.h"
 
+#include "src/tint/switch.h"
 #include "src/tint/type/abstract_float.h"
 #include "src/tint/type/abstract_int.h"
 #include "src/tint/type/array.h"
diff --git a/src/tint/writer/append_vector.cc b/src/tint/writer/append_vector.cc
index 98d9798..b66f9fd 100644
--- a/src/tint/writer/append_vector.cc
+++ b/src/tint/writer/append_vector.cc
@@ -21,6 +21,7 @@
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
 #include "src/tint/sem/value_expression.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/transform.h"
 
 using namespace tint::number_suffixes;  // NOLINT
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 61b281f..ec1f541 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -40,6 +40,7 @@
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/add_block_attribute.h"
 #include "src/tint/transform/add_empty_entry_point.h"
 #include "src/tint/transform/binding_remapper.h"
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index ada0399..cb3630e 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -40,6 +40,7 @@
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/add_empty_entry_point.h"
 #include "src/tint/transform/array_length_from_uniform.h"
 #include "src/tint/transform/builtin_polyfill.h"
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 92bdb96..3210370 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -40,6 +40,7 @@
 #include "src/tint/sem/value_constructor.h"
 #include "src/tint/sem/value_conversion.h"
 #include "src/tint/sem/variable.h"
+#include "src/tint/switch.h"
 #include "src/tint/transform/array_length_from_uniform.h"
 #include "src/tint/transform/builtin_polyfill.h"
 #include "src/tint/transform/canonicalize_entry_point_io.h"
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 69aaabd..0a167eb 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -34,6 +34,7 @@
 #include "src/tint/ast/workgroup_attribute.h"
 #include "src/tint/sem/struct.h"
 #include "src/tint/sem/switch_statement.h"
+#include "src/tint/switch.h"
 #include "src/tint/utils/math.h"
 #include "src/tint/utils/scoped_assignment.h"
 #include "src/tint/writer/float_to_string.h"