Add tint::Switch()

A type dispatch helper which replaces chains of:

  if (auto* a = obj->As<A>()) {
  } else if (auto* b = obj->As<B>()) {
  } else {
  }

with:

  Switch(obj,
    [&](A* a) { ... },
    [&](B* b) { ... },
    [&](Default) { ... });

This new helper provides greater opportunities for optimizations, avoids
scoping issues with if-else blocks, and is slightly cleaner (IMO).

Bug: tint:1383
Change-Id: Ice469a03342ef57cbcf65f69753e4b528ac50137
Kokoro: Kokoro <>
Reviewed-by: David Neto <>
Commit-Queue: Ben Clayton <>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8c0f20a..600e91f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1160,6 +1160,7 @@
+    ""
diff --git a/src/ast/ b/src/ast/
index 24999d2..3f06a31 100644
--- a/src/ast/
+++ b/src/ast/
@@ -35,16 +35,15 @@
-    if (auto* ty = decl->As<ast::TypeDecl>()) {
-      type_decls_.push_back(ty);
-    } else if (auto* func = decl->As<Function>()) {
-      functions_.push_back(func);
-    } else if (auto* var = decl->As<Variable>()) {
-      global_variables_.push_back(var);
-    } else {
-      diag::List diagnostics;
-      TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
-    }
+    Switch(
+        decl,  //
+        [&](const ast::TypeDecl* type) { type_decls_.push_back(type); },
+        [&](const Function* func) { functions_.push_back(func); },
+        [&](const Variable* var) { global_variables_.push_back(var); },
+        [&](Default) {
+          diag::List diagnostics;
+          TINT_ICE(AST, diagnostics) << "Unknown global declaration type";
+        });
@@ -101,19 +100,24 @@
           << "src global declaration was nullptr";
-    if (auto* type = decl->As<ast::TypeDecl>()) {
-      type_decls_.push_back(type);
-    } else if (auto* func = decl->As<Function>()) {
-      functions_.push_back(func);
-    } else if (auto* var = decl->As<Variable>()) {
-      global_variables_.push_back(var);
-    } else {
-      TINT_ICE(AST, ctx->dst->Diagnostics())
-          << "Unknown global declaration type";
-    }
+    Switch(
+        decl,
+        [&](const ast::TypeDecl* type) {
+          TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
+          type_decls_.push_back(type);
+        },
+        [&](const Function* func) {
+          TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
+          functions_.push_back(func);
+        },
+        [&](const Variable* var) {
+          TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, var, program_id);
+          global_variables_.push_back(var);
+        },
+        [&](Default) {
+          TINT_ICE(AST, ctx->dst->Diagnostics())
+              << "Unknown global declaration type";
+        });
diff --git a/src/ast/traverse_expressions.h b/src/ast/traverse_expressions.h
index 88d3dfc..b578941 100644
--- a/src/ast/traverse_expressions.h
+++ b/src/ast/traverse_expressions.h
@@ -101,30 +101,47 @@
-    if (auto* idx = expr->As<IndexAccessorExpression>()) {
-      push_pair(idx->object, idx->index);
-    } else if (auto* bin_op = expr->As<BinaryExpression>()) {
-      push_pair(bin_op->lhs, bin_op->rhs);
-    } else if (auto* bitcast = expr->As<BitcastExpression>()) {
-      to_visit.push_back(bitcast->expr);
-    } else if (auto* call = expr->As<CallExpression>()) {
-      // TODO( Resolver breaks if we actually include the
-      // function name in the traversal.
-      // to_visit.push_back(call->func);
-      push_list(call->args);
-    } else if (auto* member = expr->As<MemberAccessorExpression>()) {
-      // TODO( Resolver breaks if we actually include the
-      // member name in the traversal.
-      // push_pair(member->structure, member->member);
-      to_visit.push_back(member->structure);
-    } else if (auto* unary = expr->As<UnaryOpExpression>()) {
-      to_visit.push_back(unary->expr);
-    } else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
-                             PhonyExpression>()) {
-      // Leaf expression
-    } else {
-      TINT_ICE(AST, diags) << "unhandled expression type: "
-                           << expr->TypeInfo().name;
+    bool ok = Switch(
+        expr,
+        [&](const IndexAccessorExpression* idx) {
+          push_pair(idx->object, idx->index);
+          return true;
+        },
+        [&](const BinaryExpression* bin_op) {
+          push_pair(bin_op->lhs, bin_op->rhs);
+          return true;
+        },
+        [&](const BitcastExpression* bitcast) {
+          to_visit.push_back(bitcast->expr);
+          return true;
+        },
+        [&](const CallExpression* call) {
+          // TODO( Resolver breaks if we actually include the
+          // function name in the traversal.
+          // to_visit.push_back(call->func);
+          push_list(call->args);
+          return true;
+        },
+        [&](const MemberAccessorExpression* member) {
+          // TODO( Resolver breaks if we actually include the
+          // member name in the traversal.
+          // push_pair(member->structure, member->member);
+          to_visit.push_back(member->structure);
+          return true;
+        },
+        [&](const UnaryOpExpression* unary) {
+          to_visit.push_back(unary->expr);
+          return true;
+        },
+        [&](Default) {
+          if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
+                            PhonyExpression>()) {
+            return true;  // Leaf expression
+          }
+          TINT_ICE(AST, diags)
+              << "unhandled expression type: " << expr->TypeInfo().name;
+          return false;
+        });
+    if (!ok) {
       return false;
diff --git a/src/castable.h b/src/castable.h
index 3104492..04f2dcc 100644
--- a/src/castable.h
+++ b/src/castable.h
@@ -453,6 +453,105 @@
+/// Default can be used as the default case for a Switch(), when all previous
+/// cases failed to match.
+/// Example:
+/// ```
+/// Switch(object,
+///     [&](TypeA*) { /* ... */ },
+///     [&](TypeB*) { /* ... */ },
+///     [&](Default) { /* If not TypeA or TypeB */ });
+/// ```
+struct Default {};
+
+/// Switch is used to dispatch one of the provided callback case handler
+/// functions based on the type of `object` and the parameter type of the case
+/// handlers. Switch will sequentially check the type of `object` against each
+/// of the switch case handler functions, and will invoke the first case handler
+/// function which has a parameter type that matches the object type. When a
+/// case handler is matched, it will be called with the single argument of
+/// `object` cast to the case handler's parameter type. Switch will invoke at
+/// most one case handler. Each of the case functions must have the signature
+/// `R(T*)` or `R(const T*)`, where `T` is the type matched by that case and `R`
+/// is the return type, consistent across all case handlers.
+/// An optional default case function with the signature `R(Default)` can be
+/// used as the last case. This default case will be called if all previous
+/// cases failed to match.
+/// Example:
+/// ```
+/// Switch(object,
+///     [&](TypeA*) { /* ... */ },
+///     [&](TypeB*) { /* ... */ });
+/// Switch(object,
+///     [&](TypeA*) { /* ... */ },
+///     [&](TypeB*) { /* ... */ },
+///     [&](Default) { /* Called if object is not TypeA or TypeB */ });
+/// ```
+/// @param object the object whose type is used to select the case handler
+/// @param first_case the first switch case
+/// @param other_cases additional switch cases (optional)
+/// @return the value returned by the called case. If no cases matched, then the
+/// zero value for the consistent case type.
+template <typename T, typename FIRST_CASE, typename... OTHER_CASES>
+traits::ReturnType<FIRST_CASE>  //
+Switch(T* object, FIRST_CASE&& first_case, OTHER_CASES&&... other_cases) {
+  using ReturnType = traits::ReturnType<FIRST_CASE>;
+  using CaseType = std::remove_pointer_t<traits::ParameterType<FIRST_CASE, 0>>;
+  static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
+  static_assert(traits::SignatureOfT<FIRST_CASE>::parameter_count == 1,
+                "Switch case must have a single parameter");
+  if constexpr (std::is_same_v<CaseType, Default>) {
+    // Default case. Must be last.
+    (void)object;  // 'object' is not used by the Default case.
+    static_assert(sizeof...(other_cases) == 0,
+                  "Switch Default case must come last");
+    if constexpr (kHasReturnType) {
+      return first_case({});
+    } else {
+      first_case({});
+      return;
+    }
+  } else {
+    // Regular case.
+    static_assert(traits::IsTypeOrDerived<CaseType, CastableBase>::value,
+                  "Switch case parameter is not a Castable pointer");
+    // Does the case match?
+    if (auto* ptr = As<CaseType>(object)) {
+      if constexpr (kHasReturnType) {
+        return first_case(ptr);
+      } else {
+        first_case(ptr);
+        return;
+      }
+    }
+    // Case did not match. Got any more cases to try?
+    if constexpr (sizeof...(other_cases) > 0) {
+      // Try the next cases. All cases must share this case's return type.
+      // Compare the declared return type of the recursive call, rather than
+      // the type of an `auto` local (which would decay references and force a
+      // copy), and enforce this for void-returning cases too.
+      using OtherReturnType =
+          decltype(Switch(object, std::forward<OTHER_CASES>(other_cases)...));
+      static_assert(std::is_same_v<OtherReturnType, ReturnType>,
+                    "Switch case types do not have consistent return type");
+      // Returning a void expression from a void function is well-formed, so
+      // this one statement handles both the void and non-void cases.
+      return Switch(object, std::forward<OTHER_CASES>(other_cases)...);
+    } else {
+      // That was the last case. No cases matched.
+      if constexpr (kHasReturnType) {
+        return {};
+      } else {
+        return;
+      }
+    }
+  }
+}
 }  // namespace tint
diff --git a/src/ b/src/
new file mode 100644
index 0000000..839a932
--- /dev/null
+++ b/src/
@@ -0,0 +1,270 @@
+// Copyright 2022 The Tint Authors.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "bench/benchmark.h"
+namespace tint {
+namespace {
+// Castable class hierarchy used by the Switch() benchmarks in this file.
+// Base has three direct subclasses (A, B, C); each of those has three
+// subclasses (e.g. A -> AA, AB, AC); and each of those has three leaf
+// subclasses (e.g. AA -> AAA, AAB, AAC). Each type name spells its path from
+// Base, so ABC derives from AB, which derives from A.
+struct Base : public tint::Castable<Base> {};
+struct A : public tint::Castable<A, Base> {};
+struct AA : public tint::Castable<AA, A> {};
+struct AAA : public tint::Castable<AAA, AA> {};
+struct AAB : public tint::Castable<AAB, AA> {};
+struct AAC : public tint::Castable<AAC, AA> {};
+struct AB : public tint::Castable<AB, A> {};
+struct ABA : public tint::Castable<ABA, AB> {};
+struct ABB : public tint::Castable<ABB, AB> {};
+struct ABC : public tint::Castable<ABC, AB> {};
+struct AC : public tint::Castable<AC, A> {};
+struct ACA : public tint::Castable<ACA, AC> {};
+struct ACB : public tint::Castable<ACB, AC> {};
+struct ACC : public tint::Castable<ACC, AC> {};
+struct B : public tint::Castable<B, Base> {};
+struct BA : public tint::Castable<BA, B> {};
+struct BAA : public tint::Castable<BAA, BA> {};
+struct BAB : public tint::Castable<BAB, BA> {};
+struct BAC : public tint::Castable<BAC, BA> {};
+struct BB : public tint::Castable<BB, B> {};
+struct BBA : public tint::Castable<BBA, BB> {};
+struct BBB : public tint::Castable<BBB, BB> {};
+struct BBC : public tint::Castable<BBC, BB> {};
+struct BC : public tint::Castable<BC, B> {};
+struct BCA : public tint::Castable<BCA, BC> {};
+struct BCB : public tint::Castable<BCB, BC> {};
+struct BCC : public tint::Castable<BCC, BC> {};
+struct C : public tint::Castable<C, Base> {};
+struct CA : public tint::Castable<CA, C> {};
+struct CAA : public tint::Castable<CAA, CA> {};
+struct CAB : public tint::Castable<CAB, CA> {};
+struct CAC : public tint::Castable<CAC, CA> {};
+struct CB : public tint::Castable<CB, C> {};
+struct CBA : public tint::Castable<CBA, CB> {};
+struct CBB : public tint::Castable<CBB, CB> {};
+struct CBC : public tint::Castable<CBC, CB> {};
+struct CC : public tint::Castable<CC, C> {};
+struct CCA : public tint::Castable<CCA, CC> {};
+struct CCB : public tint::Castable<CCB, CC> {};
+struct CCC : public tint::Castable<CCC, CC> {};
+// Every type of the hierarchy above (40 in total), listed in declaration
+// order.
+using AllTypes = std::tuple<Base,
+                            A,
+                            AA,
+                            AAA,
+                            AAB,
+                            AAC,
+                            AB,
+                            ABA,
+                            ABB,
+                            ABC,
+                            AC,
+                            ACA,
+                            ACB,
+                            ACC,
+                            B,
+                            BA,
+                            BAA,
+                            BAB,
+                            BAC,
+                            BB,
+                            BBA,
+                            BBB,
+                            BBC,
+                            BC,
+                            BCA,
+                            BCB,
+                            BCC,
+                            C,
+                            CA,
+                            CAA,
+                            CAB,
+                            CAC,
+                            CB,
+                            CBA,
+                            CBB,
+                            CBC,
+                            CC,
+                            CCA,
+                            CCB,
+                            CCC>;
+/// Creates one default-constructed object of each type in TYPES, in the order
+/// the types are listed. The pointer argument only serves to deduce TYPES from
+/// a std::tuple type; its value is ignored.
+template <typename... TYPES>
+std::vector<std::unique_ptr<Base>> MakeObjectsOf(std::tuple<TYPES...>*) {
+  std::vector<std::unique_ptr<Base>> out;
+  out.reserve(sizeof...(TYPES));
+  (out.emplace_back(std::make_unique<TYPES>()), ...);
+  return out;
+}
+
+/// @returns one object of every type listed in AllTypes, in that order.
+/// Deriving the list from AllTypes keeps the two from drifting apart.
+std::vector<std::unique_ptr<Base>> MakeObjects() {
+  return MakeObjectsOf(static_cast<AllTypes*>(nullptr));
+}
+
+/// Benchmarks a Switch() with 28 typed cases plus a Default case. Each
+/// iteration dispatches on the object at index `i % objects.size()`, where `i`
+/// is updated by the case body and then bit-mixed.
+void CastableLargeSwitch(::benchmark::State& state) {
+  auto objects = MakeObjects();
+  size_t i = 0;
+  for (auto _ : state) {
+    auto* object = objects[i % objects.size()].get();
+    Switch(
+        object,  //
+        [&](const AAA*) { ::benchmark::DoNotOptimize(i += 40); },
+        [&](const AAB*) { ::benchmark::DoNotOptimize(i += 50); },
+        [&](const AAC*) { ::benchmark::DoNotOptimize(i += 60); },
+        [&](const ABA*) { ::benchmark::DoNotOptimize(i += 80); },
+        [&](const ABB*) { ::benchmark::DoNotOptimize(i += 90); },
+        [&](const ABC*) { ::benchmark::DoNotOptimize(i += 100); },
+        [&](const ACA*) { ::benchmark::DoNotOptimize(i += 120); },
+        [&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
+        [&](const ACC*) { ::benchmark::DoNotOptimize(i += 140); },
+        [&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
+        [&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
+        [&](const BAC*) { ::benchmark::DoNotOptimize(i += 190); },
+        [&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
+        [&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
+        [&](const BBC*) { ::benchmark::DoNotOptimize(i += 230); },
+        [&](const BCA*) { ::benchmark::DoNotOptimize(i += 250); },
+        [&](const BCB*) { ::benchmark::DoNotOptimize(i += 260); },
+        [&](const BCC*) { ::benchmark::DoNotOptimize(i += 270); },
+        [&](const CA*) { ::benchmark::DoNotOptimize(i += 290); },
+        [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
+        [&](const CAB*) { ::benchmark::DoNotOptimize(i += 310); },
+        [&](const CAC*) { ::benchmark::DoNotOptimize(i += 320); },
+        [&](const CBA*) { ::benchmark::DoNotOptimize(i += 340); },
+        [&](const CBB*) { ::benchmark::DoNotOptimize(i += 350); },
+        [&](const CBC*) { ::benchmark::DoNotOptimize(i += 360); },
+        [&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
+        [&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
+        [&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
+        [&](Default) { ::benchmark::DoNotOptimize(i += 123); });
+    // Mix the bits of `i` to choose the next object index.
+    i = (i * 31) ^ (i << 5);
+  }
+}
+
+BENCHMARK(CastableLargeSwitch);
+
+/// Benchmarks a Switch() with 9 typed cases plus a Default case. Each
+/// iteration dispatches on the object at index `i % objects.size()`, where `i`
+/// is updated by the case body and then bit-mixed.
+void CastableMediumSwitch(::benchmark::State& state) {
+  auto objects = MakeObjects();
+  size_t i = 0;
+  for (auto _ : state) {
+    auto* object = objects[i % objects.size()].get();
+    Switch(
+        object,  //
+        [&](const ACB*) { ::benchmark::DoNotOptimize(i += 130); },
+        [&](const BAA*) { ::benchmark::DoNotOptimize(i += 170); },
+        [&](const BAB*) { ::benchmark::DoNotOptimize(i += 180); },
+        [&](const BBA*) { ::benchmark::DoNotOptimize(i += 210); },
+        [&](const BBB*) { ::benchmark::DoNotOptimize(i += 220); },
+        [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); },
+        [&](const CCA*) { ::benchmark::DoNotOptimize(i += 380); },
+        [&](const CCB*) { ::benchmark::DoNotOptimize(i += 390); },
+        [&](const CCC*) { ::benchmark::DoNotOptimize(i += 400); },
+        [&](Default) { ::benchmark::DoNotOptimize(i += 123); });
+    // Mix the bits of `i` to choose the next object index.
+    i = (i * 31) ^ (i << 5);
+  }
+}
+
+BENCHMARK(CastableMediumSwitch);
+
+/// Benchmarks a Switch() with 3 typed cases and no Default case. Each
+/// iteration dispatches on the object at index `i % objects.size()`, where `i`
+/// is updated by the case body (if any matched) and then bit-mixed.
+void CastableSmallSwitch(::benchmark::State& state) {
+  auto objects = MakeObjects();
+  size_t i = 0;
+  for (auto _ : state) {
+    auto* object = objects[i % objects.size()].get();
+    Switch(
+        object,  //
+        [&](const AAB*) { ::benchmark::DoNotOptimize(i += 30); },
+        [&](const CAC*) { ::benchmark::DoNotOptimize(i += 290); },
+        [&](const CAA*) { ::benchmark::DoNotOptimize(i += 300); });
+    // Mix the bits of `i` to choose the next object index.
+    i = (i * 31) ^ (i << 5);
+  }
+}
+
+BENCHMARK(CastableSmallSwitch);
+
+}  // namespace
+}  // namespace tint
diff --git a/src/ b/src/
index e44983b..2a9a71a 100644
--- a/src/
+++ b/src/
@@ -252,6 +252,151 @@
   ASSERT_EQ(gecko->As<Reptile>(), static_cast<Reptile*>(gecko.get()));
+// Each animal is expected to hit only the case for its own class; no Default
+// case is provided.
+TEST(Castable, SwitchNoDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    bool frog_matched_amphibian = false;
+    Switch(
+        frog.get(),  //
+        [&](Reptile*) { FAIL() << "frog is not reptile"; },
+        [&](Mammal*) { FAIL() << "frog is not mammal"; },
+        [&](Amphibian* amphibian) {
+          EXPECT_EQ(amphibian, frog.get());
+          frog_matched_amphibian = true;
+        });
+    EXPECT_TRUE(frog_matched_amphibian);
+  }
+  {
+    bool bear_matched_mammal = false;
+    Switch(
+        bear.get(),  //
+        [&](Reptile*) { FAIL() << "bear is not reptile"; },
+        [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+        [&](Mammal* mammal) {
+          EXPECT_EQ(mammal, bear.get());
+          bear_matched_mammal = true;
+        });
+    EXPECT_TRUE(bear_matched_mammal);
+  }
+  {
+    bool gecko_matched_reptile = false;
+    Switch(
+        gecko.get(),  //
+        [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+        [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+        [&](Reptile* reptile) {
+          EXPECT_EQ(reptile, gecko.get());
+          gecko_matched_reptile = true;
+        });
+    EXPECT_TRUE(gecko_matched_reptile);
+  }
+}
+
+// A Default case is provided, but it must not be called when an earlier typed
+// case matches.
+TEST(Castable, SwitchWithUnusedDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    bool frog_matched_amphibian = false;
+    Switch(
+        frog.get(),  //
+        [&](Reptile*) { FAIL() << "frog is not reptile"; },
+        [&](Mammal*) { FAIL() << "frog is not mammal"; },
+        [&](Amphibian* amphibian) {
+          EXPECT_EQ(amphibian, frog.get());
+          frog_matched_amphibian = true;
+        },
+        [&](Default) { FAIL() << "default should not have been selected"; });
+    EXPECT_TRUE(frog_matched_amphibian);
+  }
+  {
+    bool bear_matched_mammal = false;
+    Switch(
+        bear.get(),  //
+        [&](Reptile*) { FAIL() << "bear is not reptile"; },
+        [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+        [&](Mammal* mammal) {
+          EXPECT_EQ(mammal, bear.get());
+          bear_matched_mammal = true;
+        },
+        [&](Default) { FAIL() << "default should not have been selected"; });
+    EXPECT_TRUE(bear_matched_mammal);
+  }
+  {
+    bool gecko_matched_reptile = false;
+    Switch(
+        gecko.get(),  //
+        [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+        [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+        [&](Reptile* reptile) {
+          EXPECT_EQ(reptile, gecko.get());
+          gecko_matched_reptile = true;
+        },
+        [&](Default) { FAIL() << "default should not have been selected"; });
+    EXPECT_TRUE(gecko_matched_reptile);
+  }
+}
+
+// None of the typed cases match, so the Default case must be called.
+TEST(Castable, SwitchDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    bool frog_matched_default = false;
+    Switch(
+        frog.get(),  //
+        [&](Reptile*) { FAIL() << "frog is not reptile"; },
+        [&](Mammal*) { FAIL() << "frog is not mammal"; },
+        [&](Default) { frog_matched_default = true; });
+    EXPECT_TRUE(frog_matched_default);
+  }
+  {
+    bool bear_matched_default = false;
+    Switch(
+        bear.get(),  //
+        [&](Reptile*) { FAIL() << "bear is not reptile"; },
+        [&](Amphibian*) { FAIL() << "bear is not amphibian"; },
+        [&](Default) { bear_matched_default = true; });
+    EXPECT_TRUE(bear_matched_default);
+  }
+  {
+    bool gecko_matched_default = false;
+    Switch(
+        gecko.get(),  //
+        [&](Mammal*) { FAIL() << "gecko is not mammal"; },
+        [&](Amphibian*) { FAIL() << "gecko is not amphibian"; },
+        [&](Default) { gecko_matched_default = true; });
+    EXPECT_TRUE(gecko_matched_default);
+  }
+}
+
+// When more than one case matches, only the first matching case is called.
+TEST(Castable, SwitchMatchFirst) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  {
+    bool frog_matched_animal = false;
+    Switch(
+        frog.get(),
+        [&](Animal* animal) {
+          EXPECT_EQ(animal, frog.get());
+          frog_matched_animal = true;
+        },
+        [&](Amphibian*) { FAIL() << "animal should have been matched first"; });
+    EXPECT_TRUE(frog_matched_animal);
+  }
+  {
+    bool frog_matched_amphibian = false;
+    Switch(
+        frog.get(),
+        [&](Amphibian* amphibian) {
+          EXPECT_EQ(amphibian, frog.get());
+          frog_matched_amphibian = true;
+        },
+        [&](Animal*) { FAIL() << "amphibian should have been matched first"; });
+    EXPECT_TRUE(frog_matched_amphibian);
+  }
+}
+
 }  // namespace
diff --git a/src/reader/spirv/ b/src/reader/spirv/
index b88916e..ed35ca4 100644
--- a/src/reader/spirv/
+++ b/src/reader/spirv/
@@ -953,7 +953,7 @@
 bool FunctionEmitter::EmitPipelineInput(std::string var_name,
                                         const Type* var_type,
-                                        ast::AttributeList* decos,
+                                        ast::AttributeList* attrs,
                                         std::vector<int> index_prefix,
                                         const Type* tip_type,
                                         const Type* forced_param_type,
@@ -966,105 +966,121 @@
   // Recursively flatten matrices, arrays, and structures.
-  if (auto* matrix_type = tip_type->As<Matrix>()) {
-    index_prefix.push_back(0);
-    const auto num_columns = static_cast<int>(matrix_type->columns);
-    const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
-    for (int col = 0; col < num_columns; col++) {
-      index_prefix.back() = col;
-      if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, vec_ty,
-                             forced_param_type, params, statements)) {
-        return false;
-      }
-    }
-    return success();
-  } else if (auto* array_type = tip_type->As<Array>()) {
-    if (array_type->size == 0) {
-      return Fail() << "runtime-size array not allowed on pipeline IO";
-    }
-    index_prefix.push_back(0);
-    const Type* elem_ty = array_type->type;
-    for (int i = 0; i < static_cast<int>(array_type->size); i++) {
-      index_prefix.back() = i;
-      if (!EmitPipelineInput(var_name, var_type, decos, index_prefix, elem_ty,
-                             forced_param_type, params, statements)) {
-        return false;
-      }
-    }
-    return success();
-  } else if (auto* struct_type = tip_type->As<Struct>()) {
-    const auto& members = struct_type->members;
-    index_prefix.push_back(0);
-    for (int i = 0; i < static_cast<int>(members.size()); ++i) {
-      index_prefix.back() = i;
-      ast::AttributeList member_decos(*decos);
-      if (!parser_impl_.ConvertPipelineDecorations(
-              struct_type,
-              parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
-              &member_decos)) {
-        return false;
-      }
-      if (!EmitPipelineInput(var_name, var_type, &member_decos, index_prefix,
-                             members[i], forced_param_type, params,
-                             statements)) {
-        return false;
-      }
-      // Copy the location as updated by nested expansion of the member.
-      parser_impl_.SetLocation(decos, GetLocation(member_decos));
-    }
-    return success();
-  }
+  return Switch(
+      tip_type,
+      [&](const Matrix* matrix_type) -> bool {
+        index_prefix.push_back(0);
+        const auto num_columns = static_cast<int>(matrix_type->columns);
+        const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
+        for (int col = 0; col < num_columns; col++) {
+          index_prefix.back() = col;
+          if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
+                                 vec_ty, forced_param_type, params,
+                                 statements)) {
+            return false;
+          }
+        }
+        return success();
+      },
+      [&](const Array* array_type) -> bool {
+        if (array_type->size == 0) {
+          return Fail() << "runtime-size array not allowed on pipeline IO";
+        }
+        index_prefix.push_back(0);
+        const Type* elem_ty = array_type->type;
+        for (int i = 0; i < static_cast<int>(array_type->size); i++) {
+          index_prefix.back() = i;
+          if (!EmitPipelineInput(var_name, var_type, attrs, index_prefix,
+                                 elem_ty, forced_param_type, params,
+                                 statements)) {
+            return false;
+          }
+        }
+        return success();
+      },
+      [&](const Struct* struct_type) -> bool {
+        const auto& members = struct_type->members;
+        index_prefix.push_back(0);
+        for (int i = 0; i < static_cast<int>(members.size()); ++i) {
+          index_prefix.back() = i;
+          ast::AttributeList member_attrs(*attrs);
+          if (!parser_impl_.ConvertPipelineDecorations(
+                  struct_type,
+                  parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
+                  &member_attrs)) {
+            return false;
+          }
+          if (!EmitPipelineInput(var_name, var_type, &member_attrs,
+                                 index_prefix, members[i], forced_param_type,
+                                 params, statements)) {
+            return false;
+          }
+          // Copy the location as updated by nested expansion of the member.
+          parser_impl_.SetLocation(attrs, GetLocation(member_attrs));
+        }
+        return success();
+      },
+      [&](Default) {
+        const bool is_builtin =
+            ast::HasAttribute<ast::BuiltinAttribute>(*attrs);
-  const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
+        const Type* param_type = is_builtin ? forced_param_type : tip_type;
-  const Type* param_type = is_builtin ? forced_param_type : tip_type;
+        const auto param_name = namer_.MakeDerivedName(var_name + "_param");
+        // Create the parameter.
+        // TODO(dneto): Note: If the parameter has non-location decorations,
+        // then those decoration AST nodes will be reused between multiple
+        // elements of a matrix, array, or structure.  Normally that's
+        // disallowed but currently the SPIR-V reader will make duplicates when
+        // the entire AST is cloned at the top level of the SPIR-V reader flow.
+        // Consider rewriting this to avoid this node-sharing.
+        params->push_back(
+            builder_.Param(param_name, param_type->Build(builder_), *attrs));
-  const auto param_name = namer_.MakeDerivedName(var_name + "_param");
-  // Create the parameter.
-  // TODO(dneto): Note: If the parameter has non-location decorations,
-  // then those decoration AST nodes will be reused between multiple elements
-  // of a matrix, array, or structure.  Normally that's disallowed but currently
-  // the SPIR-V reader will make duplicates when the entire AST is cloned
-  // at the top level of the SPIR-V reader flow.  Consider rewriting this
-  // to avoid this node-sharing.
-  params->push_back(
-      builder_.Param(param_name, param_type->Build(builder_), *decos));
+        // Add a body statement to copy the parameter to the corresponding
+        // private variable.
+        const ast::Expression* param_value = builder_.Expr(param_name);
+        const ast::Expression* store_dest = builder_.Expr(var_name);
-  // Add a body statement to copy the parameter to the corresponding private
-  // variable.
-  const ast::Expression* param_value = builder_.Expr(param_name);
-  const ast::Expression* store_dest = builder_.Expr(var_name);
+        // Index into the LHS as needed.
+        auto* current_type =
+            var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
+        for (auto index : index_prefix) {
+          Switch(
+              current_type,
+              [&](const Matrix* matrix_type) {
+                store_dest =
+                    builder_.IndexAccessor(store_dest, builder_.Expr(index));
+                current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
+              },
+              [&](const Array* array_type) {
+                store_dest =
+                    builder_.IndexAccessor(store_dest, builder_.Expr(index));
+                current_type = array_type->type->UnwrapAlias();
+              },
+              [&](const Struct* struct_type) {
+                store_dest = builder_.MemberAccessor(
+                    store_dest, builder_.Expr(parser_impl_.GetMemberName(
+                                    *struct_type, index)));
+                current_type = struct_type->members[index];
+              });
+        }
-  // Index into the LHS as needed.
-  auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
-  for (auto index : index_prefix) {
-    if (auto* matrix_type = current_type->As<Matrix>()) {
-      store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
-      current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
-    } else if (auto* array_type = current_type->As<Array>()) {
-      store_dest = builder_.IndexAccessor(store_dest, builder_.Expr(index));
-      current_type = array_type->type->UnwrapAlias();
-    } else if (auto* struct_type = current_type->As<Struct>()) {
-      store_dest = builder_.MemberAccessor(
-          store_dest,
-          builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
-      current_type = struct_type->members[index];
-    }
-  }
+        if (is_builtin && (tip_type != forced_param_type)) {
+          // The parameter will have the WGSL type, but we need bitcast to
+          // the variable store type.
+          param_value = create<ast::BitcastExpression>(
+              tip_type->Build(builder_), param_value);
+        }
-  if (is_builtin && (tip_type != forced_param_type)) {
-    // The parameter will have the WGSL type, but we need bitcast to
-    // the variable store type.
-    param_value =
-        create<ast::BitcastExpression>(tip_type->Build(builder_), param_value);
-  }
+        statements->push_back(builder_.Assign(store_dest, param_value));
-  statements->push_back(builder_.Assign(store_dest, param_value));
+        // Increment the location attribute, in case more parameters will
+        // follow.
+        IncrementLocation(attrs);
-  // Increment the location attribute, in case more parameters will follow.
-  IncrementLocation(decos);
-  return success();
+        return success();
+      });
 void FunctionEmitter::IncrementLocation(ast::AttributeList* attributes) {
@@ -1102,106 +1118,120 @@
   // Recursively flatten matrices, arrays, and structures.
-  if (auto* matrix_type = tip_type->As<Matrix>()) {
-    index_prefix.push_back(0);
-    const auto num_columns = static_cast<int>(matrix_type->columns);
-    const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
-    for (int col = 0; col < num_columns; col++) {
-      index_prefix.back() = col;
-      if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, vec_ty,
-                              forced_member_type, return_members,
-                              return_exprs)) {
-        return false;
-      }
-    }
-    return success();
-  } else if (auto* array_type = tip_type->As<Array>()) {
-    if (array_type->size == 0) {
-      return Fail() << "runtime-size array not allowed on pipeline IO";
-    }
-    index_prefix.push_back(0);
-    const Type* elem_ty = array_type->type;
-    for (int i = 0; i < static_cast<int>(array_type->size); i++) {
-      index_prefix.back() = i;
-      if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix, elem_ty,
-                              forced_member_type, return_members,
-                              return_exprs)) {
-        return false;
-      }
-    }
-    return success();
-  } else if (auto* struct_type = tip_type->As<Struct>()) {
-    const auto& members = struct_type->members;
-    index_prefix.push_back(0);
-    for (int i = 0; i < static_cast<int>(members.size()); ++i) {
-      index_prefix.back() = i;
-      ast::AttributeList member_decos(*decos);
-      if (!parser_impl_.ConvertPipelineDecorations(
-              struct_type,
-              parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
-              &member_decos)) {
-        return false;
-      }
-      if (!EmitPipelineOutput(var_name, var_type, &member_decos, index_prefix,
-                              members[i], forced_member_type, return_members,
-                              return_exprs)) {
-        return false;
-      }
-      // Copy the location as updated by nested expansion of the member.
-      parser_impl_.SetLocation(decos, GetLocation(member_decos));
-    }
-    return success();
-  }
+  return Switch(
+      tip_type,
+      [&](const Matrix* matrix_type) -> bool {
+        index_prefix.push_back(0);
+        const auto num_columns = static_cast<int>(matrix_type->columns);
+        const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
+        for (int col = 0; col < num_columns; col++) {
+          index_prefix.back() = col;
+          if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
+                                  vec_ty, forced_member_type, return_members,
+                                  return_exprs)) {
+            return false;
+          }
+        }
+        return success();
+      },
+      [&](const Array* array_type) -> bool {
+        if (array_type->size == 0) {
+          return Fail() << "runtime-size array not allowed on pipeline IO";
+        }
+        index_prefix.push_back(0);
+        const Type* elem_ty = array_type->type;
+        for (int i = 0; i < static_cast<int>(array_type->size); i++) {
+          index_prefix.back() = i;
+          if (!EmitPipelineOutput(var_name, var_type, decos, index_prefix,
+                                  elem_ty, forced_member_type, return_members,
+                                  return_exprs)) {
+            return false;
+          }
+        }
+        return success();
+      },
+      [&](const Struct* struct_type) -> bool {
+        const auto& members = struct_type->members;
+        index_prefix.push_back(0);
+        for (int i = 0; i < static_cast<int>(members.size()); ++i) {
+          index_prefix.back() = i;
+          ast::AttributeList member_attrs(*decos);
+          if (!parser_impl_.ConvertPipelineDecorations(
+                  struct_type,
+                  parser_impl_.GetMemberPipelineDecorations(*struct_type, i),
+                  &member_attrs)) {
+            return false;
+          }
+          if (!EmitPipelineOutput(var_name, var_type, &member_attrs,
+                                  index_prefix, members[i], forced_member_type,
+                                  return_members, return_exprs)) {
+            return false;
+          }
+          // Copy the location as updated by nested expansion of the member.
+          parser_impl_.SetLocation(decos, GetLocation(member_attrs));
+        }
+        return success();
+      },
+      [&](Default) {
+        // Any other type is a leaf: emit it as one member of the return
+        // structure, plus the expression that loads it from the variable.
+        const bool is_builtin =
+            ast::HasAttribute<ast::BuiltinAttribute>(*decos);
-  const bool is_builtin = ast::HasAttribute<ast::BuiltinAttribute>(*decos);
+        const Type* member_type = is_builtin ? forced_member_type : tip_type;
+        // Derive the member name directly from the variable name.  They can't
+        // collide.
+        const auto member_name = namer_.MakeDerivedName(var_name);
+        // Create the member.
+        // TODO(dneto): Note: If the parameter has non-location decorations,
+        // then those decoration AST nodes will be reused between multiple
+        // elements of a matrix, array, or structure.  Normally that's
+        // disallowed but currently the SPIR-V reader will make duplicates when
+        // the entire AST is cloned at the top level of the SPIR-V reader flow.
+        // Consider rewriting this to avoid this node-sharing.
+        return_members->push_back(
+            builder_.Member(member_name, member_type->Build(builder_), *decos));
-  const Type* member_type = is_builtin ? forced_member_type : tip_type;
-  // Derive the member name directly from the variable name.  They can't
-  // collide.
-  const auto member_name = namer_.MakeDerivedName(var_name);
-  // Create the member.
-  // TODO(dneto): Note: If the parameter has non-location decorations,
-  // then those decoration AST nodes  will be reused between multiple elements
-  // of a matrix, array, or structure.  Normally that's disallowed but currently
-  // the SPIR-V reader will make duplicates when the entire AST is cloned
-  // at the top level of the SPIR-V reader flow.  Consider rewriting this
-  // to avoid this node-sharing.
-  return_members->push_back(
-      builder_.Member(member_name, member_type->Build(builder_), *decos));
+        // Create an expression to evaluate the part of the variable indexed by
+        // the index_prefix.
+        const ast::Expression* load_source = builder_.Expr(var_name);
-  // Create an expression to evaluate the part of the variable indexed by
-  // the index_prefix.
-  const ast::Expression* load_source = builder_.Expr(var_name);
+        // Index into the variable as needed to pick out the flattened member.
+        auto* current_type =
+            var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
+        for (auto index : index_prefix) {
+          // Only matrix, array and struct types are indexed into; there is no
+          // Default case, so any other type leaves load_source unchanged.
+          Switch(
+              current_type,
+              [&](const Matrix* matrix_type) {
+                load_source =
+                    builder_.IndexAccessor(load_source, builder_.Expr(index));
+                current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
+              },
+              [&](const Array* array_type) {
+                load_source =
+                    builder_.IndexAccessor(load_source, builder_.Expr(index));
+                current_type = array_type->type->UnwrapAlias();
+              },
+              [&](const Struct* struct_type) {
+                load_source = builder_.MemberAccessor(
+                    load_source, builder_.Expr(parser_impl_.GetMemberName(
+                                     *struct_type, index)));
+                current_type = struct_type->members[index];
+              });
+        }
-  // Index into the variable as needed to pick out the flattened member.
-  auto* current_type = var_type->UnwrapAlias()->UnwrapRef()->UnwrapAlias();
-  for (auto index : index_prefix) {
-    if (auto* matrix_type = current_type->As<Matrix>()) {
-      load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
-      current_type = ty_.Vector(matrix_type->type, matrix_type->rows);
-    } else if (auto* array_type = current_type->As<Array>()) {
-      load_source = builder_.IndexAccessor(load_source, builder_.Expr(index));
-      current_type = array_type->type->UnwrapAlias();
-    } else if (auto* struct_type = current_type->As<Struct>()) {
-      load_source = builder_.MemberAccessor(
-          load_source,
-          builder_.Expr(parser_impl_.GetMemberName(*struct_type, index)));
-      current_type = struct_type->members[index];
-    }
-  }
+        if (is_builtin && (tip_type != forced_member_type)) {
+          // The member will have the WGSL type, but we need bitcast to
+          // the variable store type.
+          load_source = create<ast::BitcastExpression>(
+              forced_member_type->Build(builder_), load_source);
+        }
+        return_exprs->push_back(load_source);
-  if (is_builtin && (tip_type != forced_member_type)) {
-    // The member will have the WGSL type, but we need bitcast to
-    // the variable store type.
-    load_source = create<ast::BitcastExpression>(
-        forced_member_type->Build(builder_), load_source);
-  }
-  return_exprs->push_back(load_source);
+        // Increment the location attribute, in case more parameters will
+        // follow.
+        IncrementLocation(decos);
-  // Increment the location attribute, in case more parameters will follow.
-  IncrementLocation(decos);
-  return success();
+        return success();
+      });
 bool FunctionEmitter::EmitEntryPointAsWrapper() {
diff --git a/src/writer/hlsl/ b/src/writer/hlsl/
index 6d20ace..8502427 100644
--- a/src/writer/hlsl/
+++ b/src/writer/hlsl/
@@ -239,39 +239,41 @@
     last_kind = kind;
-    if (auto* global = decl->As<ast::Variable>()) {
-      if (!EmitGlobalVariable(global)) {
-        return false;
-      }
-    } else if (auto* str = decl->As<ast::Struct>()) {
-      auto* ty = builder_.Sem().Get(str);
-      auto storage_class_uses = ty->StorageClassUsage();
-      if (storage_class_uses.size() !=
-          (storage_class_uses.count(ast::StorageClass::kStorage) +
-           storage_class_uses.count(ast::StorageClass::kUniform))) {
-        // The structure is used as something other than a storage buffer or
-        // uniform buffer, so it needs to be emitted.
-        // Storage buffer are read and written to via a ByteAddressBuffer
-        // instead of true structure.
-        // Structures used as uniform buffer are read from an array of vectors
-        // instead of true structure.
-        if (!EmitStructType(current_buffer_, ty)) {
+    // Each handler returns false if emitting the declaration failed.
+    bool ok = Switch(
+        decl,
+        [&](const ast::Variable* global) {  //
+          return EmitGlobalVariable(global);
+        },
+        [&](const ast::Struct* str) {
+          auto* ty = builder_.Sem().Get(str);
+          auto storage_class_uses = ty->StorageClassUsage();
+          if (storage_class_uses.size() !=
+              (storage_class_uses.count(ast::StorageClass::kStorage) +
+               storage_class_uses.count(ast::StorageClass::kUniform))) {
+            // The structure is used as something other than a storage buffer
+            // or uniform buffer, so it needs to be emitted.
+            // Storage buffers are read and written to via a ByteAddressBuffer
+            // instead of a true structure.
+            // Structures used as uniform buffers are read from an array of
+            // vectors instead of a true structure.
+            return EmitStructType(current_buffer_, ty);
+          }
+          return true;
+        },
+        [&](const ast::Function* func) {
+          if (func->IsEntryPoint()) {
+            return EmitEntryPointFunction(func);
+          }
+          return EmitFunction(func);
+        },
+        [&](Default) {
+          TINT_ICE(Writer, diagnostics_)
+              << "unhandled module-scope declaration: "
+              << decl->TypeInfo().name;
           return false;
-        }
-      }
-    } else if (auto* func = decl->As<ast::Function>()) {
-      if (func->IsEntryPoint()) {
-        if (!EmitEntryPointFunction(func)) {
-          return false;
-        }
-      } else {
-        if (!EmitFunction(func)) {
-          return false;
-        }
-      }
-    } else {
-      TINT_ICE(Writer, diagnostics_)
-          << "unhandled module-scope declaration: " << decl->TypeInfo().name;
+        });
+    if (!ok) {
       return false;
@@ -929,22 +931,25 @@
                              const ast::CallExpression* expr) {
   auto* call = builder_.Sem().Get(expr);
   auto* target = call->Target();
-  if (auto* func = target->As<sem::Function>()) {
-    return EmitFunctionCall(out, call, func);
-  }
-  if (auto* builtin = target->As<sem::Builtin>()) {
-    return EmitBuiltinCall(out, call, builtin);
-  }
-  if (auto* conv = target->As<sem::TypeConversion>()) {
-    return EmitTypeConversion(out, call, conv);
-  }
-  if (auto* ctor = target->As<sem::TypeConstructor>()) {
-    return EmitTypeConstructor(out, call, ctor);
-  }
-  TINT_ICE(Writer, diagnostics_)
-      << "unhandled call target: " << target->TypeInfo().name;
-  return false;
+  // Dispatch on the kind of call target. An unrecognized target is an ICE.
+  return Switch(
+      target,
+      [&](const sem::Function* func) {
+        return EmitFunctionCall(out, call, func);
+      },
+      [&](const sem::Builtin* builtin) {
+        return EmitBuiltinCall(out, call, builtin);
+      },
+      [&](const sem::TypeConversion* conv) {
+        return EmitTypeConversion(out, call, conv);
+      },
+      [&](const sem::TypeConstructor* ctor) {
+        return EmitTypeConstructor(out, call, ctor);
+      },
+      [&](Default) {
+        TINT_ICE(Writer, diagnostics_)
+            << "unhandled call target: " << target->TypeInfo().name;
+        return false;
+      });
 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@@ -2639,35 +2644,38 @@
 bool GeneratorImpl::EmitExpression(std::ostream& out,
                                    const ast::Expression* expr) {
-  if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
-    return EmitIndexAccessor(out, a);
-  }
-  if (auto* b = expr->As<ast::BinaryExpression>()) {
-    return EmitBinary(out, b);
-  }
-  if (auto* b = expr->As<ast::BitcastExpression>()) {
-    return EmitBitcast(out, b);
-  }
-  if (auto* c = expr->As<ast::CallExpression>()) {
-    return EmitCall(out, c);
-  }
-  if (auto* i = expr->As<ast::IdentifierExpression>()) {
-    return EmitIdentifier(out, i);
-  }
-  if (auto* l = expr->As<ast::LiteralExpression>()) {
-    return EmitLiteral(out, l);
-  }
-  if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
-    return EmitMemberAccessor(out, m);
-  }
-  if (auto* u = expr->As<ast::UnaryOpExpression>()) {
-    return EmitUnaryOp(out, u);
-  }
-  diagnostics_.add_error(
-      diag::System::Writer,
-      "unknown expression type: " + std::string(expr->TypeInfo().name));
-  return false;
+  // Dispatch to the emitter for the concrete expression type. Unknown
+  // expression types are reported as a writer error.
+  return Switch(
+      expr,
+      [&](const ast::IndexAccessorExpression* a) {  //
+        return EmitIndexAccessor(out, a);
+      },
+      [&](const ast::BinaryExpression* b) {  //
+        return EmitBinary(out, b);
+      },
+      [&](const ast::BitcastExpression* b) {  //
+        return EmitBitcast(out, b);
+      },
+      [&](const ast::CallExpression* c) {  //
+        return EmitCall(out, c);
+      },
+      [&](const ast::IdentifierExpression* i) {  //
+        return EmitIdentifier(out, i);
+      },
+      [&](const ast::LiteralExpression* l) {  //
+        return EmitLiteral(out, l);
+      },
+      [&](const ast::MemberAccessorExpression* m) {  //
+        return EmitMemberAccessor(out, m);
+      },
+      [&](const ast::UnaryOpExpression* u) {  //
+        return EmitUnaryOp(out, u);
+      },
+      [&](Default) {  //
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown expression type: " + std::string(expr->TypeInfo().name));
+        return false;
+      });
 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@@ -3127,80 +3135,108 @@
 bool GeneratorImpl::EmitLiteral(std::ostream& out,
                                 const ast::LiteralExpression* lit) {
-  if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
-    out << (l->value ? "true" : "false");
-  } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
-    if (std::isinf(fl->value)) {
-      out << (fl->value >= 0 ? "asfloat(0x7f800000u)" : "asfloat(0xff800000u)");
-    } else if (std::isnan(fl->value)) {
-      out << "asfloat(0x7fc00000u)";
-    } else {
-      out << FloatToString(fl->value) << "f";
-    }
-  } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
-    out << sl->value;
-  } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
-    out << ul->value << "u";
-  } else {
-    diagnostics_.add_error(diag::System::Writer, "unknown literal type");
-    return false;
-  }
-  return true;
+  return Switch(
+      lit,
+      [&](const ast::BoolLiteralExpression* l) {
+        out << (l->value ? "true" : "false");
+        return true;
+      },
+      [&](const ast::FloatLiteralExpression* fl) {
+        if (std::isinf(fl->value)) {
+          out << (fl->value >= 0 ? "asfloat(0x7f800000u)"
+                                 : "asfloat(0xff800000u)");
+        } else if (std::isnan(fl->value)) {
+          out << "asfloat(0x7fc00000u)";
+        } else {
+          out << FloatToString(fl->value) << "f";
+        }
+        return true;
+      },
+      [&](const ast::SintLiteralExpression* sl) {
+        out << sl->value;
+        return true;
+      },
+      [&](const ast::UintLiteralExpression* ul) {
+        out << ul->value << "u";
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(diag::System::Writer, "unknown literal type");
+        return false;
+      });
 bool GeneratorImpl::EmitValue(std::ostream& out,
                               const sem::Type* type,
                               int value) {
-  if (type->Is<sem::Bool>()) {
-    out << (value == 0 ? "false" : "true");
-  } else if (type->Is<sem::F32>()) {
-    out << value << ".0f";
-  } else if (type->Is<sem::I32>()) {
-    out << value;
-  } else if (type->Is<sem::U32>()) {
-    out << value << "u";
-  } else if (auto* vec = type->As<sem::Vector>()) {
-    if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
-                  "")) {
-      return false;
-    }
-    ScopedParen sp(out);
-    for (uint32_t i = 0; i < vec->Width(); i++) {
-      if (i != 0) {
-        out << ", ";
-      }
-      if (!EmitValue(out, vec->type(), value)) {
+  // Structures and arrays are both emitted as `(<type>)<value>`. The closing
+  // parenthesis and the value are only written once the type has been
+  // emitted successfully, so a failed EmitType() leaves no trailing text.
+  auto emit_cast_to_type = [&]() {
+    out << "(";
+    if (!EmitType(out, type, ast::StorageClass::kNone,
+                  ast::Access::kUndefined, "")) {
+      return false;
+    }
+    out << ")" << value;
+    return true;
+  };
+  return Switch(
+      type,
+      [&](const sem::Bool*) {
+        out << (value == 0 ? "false" : "true");
+        return true;
+      },
+      [&](const sem::F32*) {
+        out << value << ".0f";
+        return true;
+      },
+      [&](const sem::I32*) {
+        out << value;
+        return true;
+      },
+      [&](const sem::U32*) {
+        out << value << "u";
+        return true;
+      },
+      [&](const sem::Vector* vec) {
+        if (!EmitType(out, type, ast::StorageClass::kNone,
+                      ast::Access::kReadWrite, "")) {
+          return false;
+        }
+        ScopedParen sp(out);
+        for (uint32_t i = 0; i < vec->Width(); i++) {
+          if (i != 0) {
+            out << ", ";
+          }
+          if (!EmitValue(out, vec->type(), value)) {
+            return false;
+          }
+        }
+        return true;
+      },
+      [&](const sem::Matrix* mat) {
+        if (!EmitType(out, type, ast::StorageClass::kNone,
+                      ast::Access::kReadWrite, "")) {
+          return false;
+        }
+        ScopedParen sp(out);
+        for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
+          if (i != 0) {
+            out << ", ";
+          }
+          if (!EmitValue(out, mat->type(), value)) {
+            return false;
+          }
+        }
+        return true;
+      },
+      [&](const sem::Struct*) {  //
+        return emit_cast_to_type();
+      },
+      [&](const sem::Array*) {  //
+        return emit_cast_to_type();
+      },
+      [&](Default) {
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "Invalid type for value emission: " + type->type_name());
         return false;
-      }
-    }
-  } else if (auto* mat = type->As<sem::Matrix>()) {
-    if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
-                  "")) {
-      return false;
-    }
-    ScopedParen sp(out);
-    for (uint32_t i = 0; i < (mat->rows() * mat->columns()); i++) {
-      if (i != 0) {
-        out << ", ";
-      }
-      if (!EmitValue(out, mat->type(), value)) {
-        return false;
-      }
-    }
-  } else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
-    out << "(";
-    if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
-                  "")) {
-      return false;
-    }
-    out << ")" << value;
-  } else {
-    diagnostics_.add_error(
-        diag::System::Writer,
-        "Invalid type for value emission: " + type->type_name());
-    return false;
-  }
-  return true;
+      });
 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
@@ -3375,56 +3411,59 @@
 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
-  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
-    return EmitAssign(a);
-  }
-  if (auto* b = stmt->As<ast::BlockStatement>()) {
-    return EmitBlock(b);
-  }
-  if (auto* b = stmt->As<ast::BreakStatement>()) {
-    return EmitBreak(b);
-  }
-  if (auto* c = stmt->As<ast::CallStatement>()) {
-    auto out = line();
-    if (!EmitCall(out, c->expr)) {
-      return false;
-    }
-    out << ";";
-    return true;
-  }
-  if (auto* c = stmt->As<ast::ContinueStatement>()) {
-    return EmitContinue(c);
-  }
-  if (auto* d = stmt->As<ast::DiscardStatement>()) {
-    return EmitDiscard(d);
-  }
-  if (stmt->As<ast::FallthroughStatement>()) {
-    line() << "/* fallthrough */";
-    return true;
-  }
-  if (auto* i = stmt->As<ast::IfStatement>()) {
-    return EmitIf(i);
-  }
-  if (auto* l = stmt->As<ast::LoopStatement>()) {
-    return EmitLoop(l);
-  }
-  if (auto* l = stmt->As<ast::ForLoopStatement>()) {
-    return EmitForLoop(l);
-  }
-  if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return EmitReturn(r);
-  }
-  if (auto* s = stmt->As<ast::SwitchStatement>()) {
-    return EmitSwitch(s);
-  }
-  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-    return EmitVariable(v->variable);
-  }
-  diagnostics_.add_error(
-      diag::System::Writer,
-      "unknown statement type: " + std::string(stmt->TypeInfo().name));
-  return false;
+  // Dispatch to the emitter for the concrete statement type. Unknown
+  // statement types are reported as a writer error.
+  return Switch(
+      stmt,
+      [&](const ast::AssignmentStatement* a) {  //
+        return EmitAssign(a);
+      },
+      [&](const ast::BlockStatement* b) {  //
+        return EmitBlock(b);
+      },
+      [&](const ast::BreakStatement* b) {  //
+        return EmitBreak(b);
+      },
+      [&](const ast::CallStatement* c) {  //
+        auto out = line();
+        if (!EmitCall(out, c->expr)) {
+          return false;
+        }
+        out << ";";
+        return true;
+      },
+      [&](const ast::ContinueStatement* c) {  //
+        return EmitContinue(c);
+      },
+      [&](const ast::DiscardStatement* d) {  //
+        return EmitDiscard(d);
+      },
+      [&](const ast::FallthroughStatement*) {  //
+        line() << "/* fallthrough */";
+        return true;
+      },
+      [&](const ast::IfStatement* i) {  //
+        return EmitIf(i);
+      },
+      [&](const ast::LoopStatement* l) {  //
+        return EmitLoop(l);
+      },
+      [&](const ast::ForLoopStatement* l) {  //
+        return EmitForLoop(l);
+      },
+      [&](const ast::ReturnStatement* r) {  //
+        return EmitReturn(r);
+      },
+      [&](const ast::SwitchStatement* s) {  //
+        return EmitSwitch(s);
+      },
+      [&](const ast::VariableDeclStatement* v) {  //
+        return EmitVariable(v->variable);
+      },
+      [&](Default) {  //
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown statement type: " + std::string(stmt->TypeInfo().name));
+        return false;
+      });
 bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
@@ -3516,156 +3555,181 @@
-  if (auto* ary = type->As<sem::Array>()) {
-    const sem::Type* base_type = ary;
-    std::vector<uint32_t> sizes;
-    while (auto* arr = base_type->As<sem::Array>()) {
-      if (arr->IsRuntimeSized()) {
+  return Switch(
+      type,
+      [&](const sem::Array* ary) {
+        const sem::Type* base_type = ary;
+        std::vector<uint32_t> sizes;
+        while (auto* arr = base_type->As<sem::Array>()) {
+          if (arr->IsRuntimeSized()) {
+            TINT_ICE(Writer, diagnostics_)
+                << "Runtime arrays may only exist in storage buffers, which "
+                   "should "
+                   "have been transformed into a ByteAddressBuffer";
+            return false;
+          }
+          sizes.push_back(arr->Count());
+          base_type = arr->ElemType();
+        }
+        if (!EmitType(out, base_type, storage_class, access, "")) {
+          return false;
+        }
+        if (!name.empty()) {
+          out << " " << name;
+          if (name_printed) {
+            *name_printed = true;
+          }
+        }
+        for (uint32_t size : sizes) {
+          out << "[" << size << "]";
+        }
+        return true;
+      },
+      [&](const sem::Bool*) {
+        out << "bool";
+        return true;
+      },
+      [&](const sem::F32*) {
+        out << "float";
+        return true;
+      },
+      [&](const sem::I32*) {
+        out << "int";
+        return true;
+      },
+      [&](const sem::Matrix* mat) {
+        if (!EmitType(out, mat->type(), storage_class, access, "")) {
+          return false;
+        }
+        // Note: HLSL's matrices are declared as <type>NxM, where N is the
+        // number of rows and M is the number of columns. Despite HLSL's
+        // matrices being column-major by default, the index operator and
+        // constructors actually operate on row-vectors, where as WGSL operates
+        // on column vectors. To simplify everything we use the transpose of the
+        // matrices. See:
+        //
+        out << mat->columns() << "x" << mat->rows();
+        return true;
+      },
+      [&](const sem::Pointer*) {
         TINT_ICE(Writer, diagnostics_)
-            << "Runtime arrays may only exist in storage buffers, which should "
-               "have been transformed into a ByteAddressBuffer";
+            << "Attempting to emit pointer type. These should have been "
+               "removed with the InlinePointerLets transform";
         return false;
-      }
-      sizes.push_back(arr->Count());
-      base_type = arr->ElemType();
-    }
-    if (!EmitType(out, base_type, storage_class, access, "")) {
-      return false;
-    }
-    if (!name.empty()) {
-      out << " " << name;
-      if (name_printed) {
-        *name_printed = true;
-      }
-    }
-    for (uint32_t size : sizes) {
-      out << "[" << size << "]";
-    }
-  } else if (type->Is<sem::Bool>()) {
-    out << "bool";
-  } else if (type->Is<sem::F32>()) {
-    out << "float";
-  } else if (type->Is<sem::I32>()) {
-    out << "int";
-  } else if (auto* mat = type->As<sem::Matrix>()) {
-    if (!EmitType(out, mat->type(), storage_class, access, "")) {
-      return false;
-    }
-    // Note: HLSL's matrices are declared as <type>NxM, where N is the number of
-    // rows and M is the number of columns. Despite HLSL's matrices being
-    // column-major by default, the index operator and constructors actually
-    // operate on row-vectors, where as WGSL operates on column vectors.
-    // To simplify everything we use the transpose of the matrices.
-    // See:
-    //
-    out << mat->columns() << "x" << mat->rows();
-  } else if (type->Is<sem::Pointer>()) {
-    TINT_ICE(Writer, diagnostics_)
-        << "Attempting to emit pointer type. These should have been removed "
-           "with the InlinePointerLets transform";
-    return false;
-  } else if (auto* sampler = type->As<sem::Sampler>()) {
-    out << "Sampler";
-    if (sampler->IsComparison()) {
-      out << "Comparison";
-    }
-    out << "State";
-  } else if (auto* str = type->As<sem::Struct>()) {
-    out << StructName(str);
-  } else if (auto* tex = type->As<sem::Texture>()) {
-    auto* storage = tex->As<sem::StorageTexture>();
-    auto* ms = tex->As<sem::MultisampledTexture>();
-    auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
-    auto* sampled = tex->As<sem::SampledTexture>();
+      },
+      [&](const sem::Sampler* sampler) {
+        out << "Sampler";
+        if (sampler->IsComparison()) {
+          out << "Comparison";
+        }
+        out << "State";
+        return true;
+      },
+      [&](const sem::Struct* str) {
+        out << StructName(str);
+        return true;
+      },
+      [&](const sem::Texture* tex) {
+        auto* storage = tex->As<sem::StorageTexture>();
+        auto* ms = tex->As<sem::MultisampledTexture>();
+        auto* depth_ms = tex->As<sem::DepthMultisampledTexture>();
+        auto* sampled = tex->As<sem::SampledTexture>();
-    if (storage && storage->access() != ast::Access::kRead) {
-      out << "RW";
-    }
-    out << "Texture";
+        if (storage && storage->access() != ast::Access::kRead) {
+          out << "RW";
+        }
+        out << "Texture";
-    switch (tex->dim()) {
-      case ast::TextureDimension::k1d:
-        out << "1D";
-        break;
-      case ast::TextureDimension::k2d:
-        out << ((ms || depth_ms) ? "2DMS" : "2D");
-        break;
-      case ast::TextureDimension::k2dArray:
-        out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
-        break;
-      case ast::TextureDimension::k3d:
-        out << "3D";
-        break;
-      case ast::TextureDimension::kCube:
-        out << "Cube";
-        break;
-      case ast::TextureDimension::kCubeArray:
-        out << "CubeArray";
-        break;
-      default:
-        TINT_UNREACHABLE(Writer, diagnostics_)
-            << "unexpected TextureDimension " << tex->dim();
-        return false;
-    }
+        switch (tex->dim()) {
+          case ast::TextureDimension::k1d:
+            out << "1D";
+            break;
+          case ast::TextureDimension::k2d:
+            out << ((ms || depth_ms) ? "2DMS" : "2D");
+            break;
+          case ast::TextureDimension::k2dArray:
+            out << ((ms || depth_ms) ? "2DMSArray" : "2DArray");
+            break;
+          case ast::TextureDimension::k3d:
+            out << "3D";
+            break;
+          case ast::TextureDimension::kCube:
+            out << "Cube";
+            break;
+          case ast::TextureDimension::kCubeArray:
+            out << "CubeArray";
+            break;
+          default:
+            TINT_UNREACHABLE(Writer, diagnostics_)
+                << "unexpected TextureDimension " << tex->dim();
+            return false;
+        }
-    if (storage) {
-      auto* component = image_format_to_rwtexture_type(storage->texel_format());
-      if (component == nullptr) {
-        TINT_ICE(Writer, diagnostics_)
-            << "Unsupported StorageTexture TexelFormat: "
-            << static_cast<int>(storage->texel_format());
+        if (storage) {
+          auto* component =
+              image_format_to_rwtexture_type(storage->texel_format());
+          if (component == nullptr) {
+            TINT_ICE(Writer, diagnostics_)
+                << "Unsupported StorageTexture TexelFormat: "
+                << static_cast<int>(storage->texel_format());
+            return false;
+          }
+          out << "<" << component << ">";
+        } else if (depth_ms) {
+          out << "<float4>";
+        } else if (sampled || ms) {
+          auto* subtype = sampled ? sampled->type() : ms->type();
+          out << "<";
+          if (subtype->Is<sem::F32>()) {
+            out << "float4";
+          } else if (subtype->Is<sem::I32>()) {
+            out << "int4";
+          } else if (subtype->Is<sem::U32>()) {
+            out << "uint4";
+          } else {
+            TINT_ICE(Writer, diagnostics_)
+                << "Unsupported multisampled texture type";
+            return false;
+          }
+          out << ">";
+        }
+        return true;
+      },
+      [&](const sem::U32*) {
+        out << "uint";
+        return true;
+      },
+      [&](const sem::Vector* vec) {
+        auto width = vec->Width();
+        if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
+          out << "float" << width;
+        } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
+          out << "int" << width;
+        } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
+          out << "uint" << width;
+        } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
+          out << "bool" << width;
+        } else {
+          out << "vector<";
+          if (!EmitType(out, vec->type(), storage_class, access, "")) {
+            return false;
+          }
+          out << ", " << width << ">";
+        }
+        return true;
+      },
+      [&](const sem::Atomic* atomic) {
+        return EmitType(out, atomic->Type(), storage_class, access, name);
+      },
+      [&](const sem::Void*) {
+        out << "void";
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(diag::System::Writer,
+                               "unknown type in EmitType");
         return false;
-      }
-      out << "<" << component << ">";
-    } else if (depth_ms) {
-      out << "<float4>";
-    } else if (sampled || ms) {
-      auto* subtype = sampled ? sampled->type() : ms->type();
-      out << "<";
-      if (subtype->Is<sem::F32>()) {
-        out << "float4";
-      } else if (subtype->Is<sem::I32>()) {
-        out << "int4";
-      } else if (subtype->Is<sem::U32>()) {
-        out << "uint4";
-      } else {
-        TINT_ICE(Writer, diagnostics_)
-            << "Unsupported multisampled texture type";
-        return false;
-      }
-      out << ">";
-    }
-  } else if (type->Is<sem::U32>()) {
-    out << "uint";
-  } else if (auto* vec = type->As<sem::Vector>()) {
-    auto width = vec->Width();
-    if (vec->type()->Is<sem::F32>() && width >= 1 && width <= 4) {
-      out << "float" << width;
-    } else if (vec->type()->Is<sem::I32>() && width >= 1 && width <= 4) {
-      out << "int" << width;
-    } else if (vec->type()->Is<sem::U32>() && width >= 1 && width <= 4) {
-      out << "uint" << width;
-    } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
-      out << "bool" << width;
-    } else {
-      out << "vector<";
-      if (!EmitType(out, vec->type(), storage_class, access, "")) {
-        return false;
-      }
-      out << ", " << width << ">";
-    }
-  } else if (auto* atomic = type->As<sem::Atomic>()) {
-    if (!EmitType(out, atomic->Type(), storage_class, access, name)) {
-      return false;
-    }
-  } else if (type->Is<sem::Void>()) {
-    out << "void";
-  } else {
-    diagnostics_.add_error(diag::System::Writer, "unknown type in EmitType");
-    return false;
-  }
-  return true;
+      });
 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
diff --git a/src/writer/msl/ b/src/writer/msl/
index ec2d748..129e969 100644
--- a/src/writer/msl/
+++ b/src/writer/msl/
@@ -538,23 +538,25 @@
                              const ast::CallExpression* expr) {
   auto* call = program_->Sem().Get(expr);
   auto* target = call->Target();
-  if (auto* func = target->As<sem::Function>()) {
-    return EmitFunctionCall(out, call, func);
-  }
-  if (auto* builtin = target->As<sem::Builtin>()) {
-    return EmitBuiltinCall(out, call, builtin);
-  }
-  if (auto* conv = target->As<sem::TypeConversion>()) {
-    return EmitTypeConversion(out, call, conv);
-  }
-  if (auto* ctor = target->As<sem::TypeConstructor>()) {
-    return EmitTypeConstructor(out, call, ctor);
-  }
-  TINT_ICE(Writer, diagnostics_)
-      << "unhandled call target: " << target->TypeInfo().name;
-  return false;
+  return Switch(
+      target,
+      [&](const sem::Function* func) {
+        return EmitFunctionCall(out, call, func);
+      },
+      [&](const sem::Builtin* builtin) {
+        return EmitBuiltinCall(out, call, builtin);
+      },
+      [&](const sem::TypeConversion* conv) {
+        return EmitTypeConversion(out, call, conv);
+      },
+      [&](const sem::TypeConstructor* ctor) {
+        return EmitTypeConstructor(out, call, ctor);
+      },
+      [&](Default) {
+        TINT_ICE(Writer, diagnostics_)
+            << "unhandled call target: " << target->TypeInfo().name;
+        return false;
+      });
 bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
@@ -1476,106 +1478,128 @@
 bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
-  if (type->Is<sem::Bool>()) {
-    out << "false";
-  } else if (type->Is<sem::F32>()) {
-    out << "0.0f";
-  } else if (type->Is<sem::I32>()) {
-    out << "0";
-  } else if (type->Is<sem::U32>()) {
-    out << "0u";
-  } else if (auto* vec = type->As<sem::Vector>()) {
-    return EmitZeroValue(out, vec->type());
-  } else if (auto* mat = type->As<sem::Matrix>()) {
-    if (!EmitType(out, mat, "")) {
-      return false;
-    }
-    out << "(";
-    if (!EmitZeroValue(out, mat->type())) {
-      return false;
-    }
-    out << ")";
-  } else if (auto* arr = type->As<sem::Array>()) {
-    out << "{";
-    if (!EmitZeroValue(out, arr->ElemType())) {
-      return false;
-    }
-    out << "}";
-  } else if (type->As<sem::Struct>()) {
-    out << "{}";
-  } else {
-    diagnostics_.add_error(
-        diag::System::Writer,
-        "Invalid type for zero emission: " + type->type_name());
-    return false;
-  }
-  return true;
+  return Switch(
+      type,
+      [&](const sem::Bool*) {
+        out << "false";
+        return true;
+      },
+      [&](const sem::F32*) {
+        out << "0.0f";
+        return true;
+      },
+      [&](const sem::I32*) {
+        out << "0";
+        return true;
+      },
+      [&](const sem::U32*) {
+        out << "0u";
+        return true;
+      },
+      [&](const sem::Vector* vec) {  //
+        return EmitZeroValue(out, vec->type());
+      },
+      [&](const sem::Matrix* mat) {
+        if (!EmitType(out, mat, "")) {
+          return false;
+        }
+        out << "(";
+        TINT_DEFER(out << ")");
+        return EmitZeroValue(out, mat->type());
+      },
+      [&](const sem::Array* arr) {
+        out << "{";
+        TINT_DEFER(out << "}");
+        return EmitZeroValue(out, arr->ElemType());
+      },
+      [&](const sem::Struct*) {
+        out << "{}";
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "Invalid type for zero emission: " + type->type_name());
+        return false;
+      });
 bool GeneratorImpl::EmitLiteral(std::ostream& out,
                                 const ast::LiteralExpression* lit) {
-  if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
-    out << (l->value ? "true" : "false");
-  } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
-    if (std::isinf(fl->value)) {
-      out << (fl->value >= 0 ? "INFINITY" : "-INFINITY");
-    } else if (std::isnan(fl->value)) {
-      out << "NAN";
-    } else {
-      out << FloatToString(fl->value) << "f";
-    }
-  } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
-    // MSL (and C++) parse `-2147483648` as a `long` because it parses unary
-    // minus and `2147483648` as separate tokens, and the latter doesn't
-    // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To avoid
-    // issues with `long` to `int` casts, emit `(2147483647 - 1)` instead, which
-    // ensures the expression type is `int`.
-    const auto int_min = std::numeric_limits<int32_t>::min();
-    if (sl->ValueAsI32() == int_min) {
-      out << "(" << int_min + 1 << " - 1)";
-    } else {
-      out << sl->value;
-    }
-  } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
-    out << ul->value << "u";
-  } else {
-    diagnostics_.add_error(diag::System::Writer, "unknown literal type");
-    return false;
-  }
-  return true;
+  return Switch(
+      lit,
+      [&](const ast::BoolLiteralExpression* l) {
+        out << (l->value ? "true" : "false");
+        return true;
+      },
+      [&](const ast::FloatLiteralExpression* l) {
+        if (std::isinf(l->value)) {
+          out << (l->value >= 0 ? "INFINITY" : "-INFINITY");
+        } else if (std::isnan(l->value)) {
+          out << "NAN";
+        } else {
+          out << FloatToString(l->value) << "f";
+        }
+        return true;
+      },
+      [&](const ast::SintLiteralExpression* l) {
+        // MSL (and C++) parse `-2147483648` as a `long` because it parses unary
+        // minus and `2147483648` as separate tokens, and the latter doesn't
+        // fit into an (32-bit) `int`. WGSL, OTOH, parses this as an `i32`. To
+        // avoid issues with `long` to `int` casts, emit `(2147483647 - 1)`
+        // instead, which ensures the expression type is `int`.
+        const auto int_min = std::numeric_limits<int32_t>::min();
+        if (l->ValueAsI32() == int_min) {
+          out << "(" << int_min + 1 << " - 1)";
+        } else {
+          out << l->value;
+        }
+        return true;
+      },
+      [&](const ast::UintLiteralExpression* l) {
+        out << l->value << "u";
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(diag::System::Writer, "unknown literal type");
+        return false;
+      });
 bool GeneratorImpl::EmitExpression(std::ostream& out,
                                    const ast::Expression* expr) {
-  if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
-    return EmitIndexAccessor(out, a);
-  }
-  if (auto* b = expr->As<ast::BinaryExpression>()) {
-    return EmitBinary(out, b);
-  }
-  if (auto* b = expr->As<ast::BitcastExpression>()) {
-    return EmitBitcast(out, b);
-  }
-  if (auto* c = expr->As<ast::CallExpression>()) {
-    return EmitCall(out, c);
-  }
-  if (auto* i = expr->As<ast::IdentifierExpression>()) {
-    return EmitIdentifier(out, i);
-  }
-  if (auto* l = expr->As<ast::LiteralExpression>()) {
-    return EmitLiteral(out, l);
-  }
-  if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
-    return EmitMemberAccessor(out, m);
-  }
-  if (auto* u = expr->As<ast::UnaryOpExpression>()) {
-    return EmitUnaryOp(out, u);
-  }
-  diagnostics_.add_error(
-      diag::System::Writer,
-      "unknown expression type: " + std::string(expr->TypeInfo().name));
-  return false;
+  return Switch(
+      expr,
+      [&](const ast::IndexAccessorExpression* a) {  //
+        return EmitIndexAccessor(out, a);
+      },
+      [&](const ast::BinaryExpression* b) {  //
+        return EmitBinary(out, b);
+      },
+      [&](const ast::BitcastExpression* b) {  //
+        return EmitBitcast(out, b);
+      },
+      [&](const ast::CallExpression* c) {  //
+        return EmitCall(out, c);
+      },
+      [&](const ast::IdentifierExpression* i) {  //
+        return EmitIdentifier(out, i);
+      },
+      [&](const ast::LiteralExpression* l) {  //
+        return EmitLiteral(out, l);
+      },
+      [&](const ast::MemberAccessorExpression* m) {  //
+        return EmitMemberAccessor(out, m);
+      },
+      [&](const ast::UnaryOpExpression* u) {  //
+        return EmitUnaryOp(out, u);
+      },
+      [&](Default) {  //
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown expression type: " + std::string(expr->TypeInfo().name));
+        return false;
+      });
 void GeneratorImpl::EmitStage(std::ostream& out, ast::PipelineStage stage) {
@@ -2106,57 +2130,60 @@
 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
-  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
-    return EmitAssign(a);
-  }
-  if (auto* b = stmt->As<ast::BlockStatement>()) {
-    return EmitBlock(b);
-  }
-  if (auto* b = stmt->As<ast::BreakStatement>()) {
-    return EmitBreak(b);
-  }
-  if (auto* c = stmt->As<ast::CallStatement>()) {
-    auto out = line();
-    if (!EmitCall(out, c->expr)) {
-      return false;
-    }
-    out << ";";
-    return true;
-  }
-  if (auto* c = stmt->As<ast::ContinueStatement>()) {
-    return EmitContinue(c);
-  }
-  if (auto* d = stmt->As<ast::DiscardStatement>()) {
-    return EmitDiscard(d);
-  }
-  if (stmt->As<ast::FallthroughStatement>()) {
-    line() << "/* fallthrough */";
-    return true;
-  }
-  if (auto* i = stmt->As<ast::IfStatement>()) {
-    return EmitIf(i);
-  }
-  if (auto* l = stmt->As<ast::LoopStatement>()) {
-    return EmitLoop(l);
-  }
-  if (auto* l = stmt->As<ast::ForLoopStatement>()) {
-    return EmitForLoop(l);
-  }
-  if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return EmitReturn(r);
-  }
-  if (auto* s = stmt->As<ast::SwitchStatement>()) {
-    return EmitSwitch(s);
-  }
-  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-    auto* var = program_->Sem().Get(v->variable);
-    return EmitVariable(var);
-  }
-  diagnostics_.add_error(
-      diag::System::Writer,
-      "unknown statement type: " + std::string(stmt->TypeInfo().name));
-  return false;
+  return Switch(
+      stmt,
+      [&](const ast::AssignmentStatement* a) {  //
+        return EmitAssign(a);
+      },
+      [&](const ast::BlockStatement* b) {  //
+        return EmitBlock(b);
+      },
+      [&](const ast::BreakStatement* b) {  //
+        return EmitBreak(b);
+      },
+      [&](const ast::CallStatement* c) {  //
+        auto out = line();
+        if (!EmitCall(out, c->expr)) {  //
+          return false;
+        }
+        out << ";";
+        return true;
+      },
+      [&](const ast::ContinueStatement* c) {  //
+        return EmitContinue(c);
+      },
+      [&](const ast::DiscardStatement* d) {  //
+        return EmitDiscard(d);
+      },
+      [&](const ast::FallthroughStatement*) {  //
+        line() << "/* fallthrough */";
+        return true;
+      },
+      [&](const ast::IfStatement* i) {  //
+        return EmitIf(i);
+      },
+      [&](const ast::LoopStatement* l) {  //
+        return EmitLoop(l);
+      },
+      [&](const ast::ForLoopStatement* l) {  //
+        return EmitForLoop(l);
+      },
+      [&](const ast::ReturnStatement* r) {  //
+        return EmitReturn(r);
+      },
+      [&](const ast::SwitchStatement* s) {  //
+        return EmitSwitch(s);
+      },
+      [&](const ast::VariableDeclStatement* v) {  //
+        auto* var = program_->Sem().Get(v->variable);
+        return EmitVariable(var);
+      },
+      [&](Default) {
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown statement type: " + std::string(stmt->TypeInfo().name));
+        return false;
+      });
 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {
@@ -2204,203 +2231,210 @@
   if (name_printed) {
     *name_printed = false;
-  if (auto* atomic = type->As<sem::Atomic>()) {
-    if (atomic->Type()->Is<sem::I32>()) {
-      out << "atomic_int";
-      return true;
-    }
-    if (atomic->Type()->Is<sem::U32>()) {
-      out << "atomic_uint";
-      return true;
-    }
-    TINT_ICE(Writer, diagnostics_)
-        << "unhandled atomic type " << atomic->Type()->type_name();
-    return false;
-  }
-  if (auto* ary = type->As<sem::Array>()) {
-    const sem::Type* base_type = ary;
-    std::vector<uint32_t> sizes;
-    while (auto* arr = base_type->As<sem::Array>()) {
-      if (arr->IsRuntimeSized()) {
-        sizes.push_back(1);
-      } else {
-        sizes.push_back(arr->Count());
-      }
-      base_type = arr->ElemType();
-    }
-    if (!EmitType(out, base_type, "")) {
-      return false;
-    }
-    if (!name.empty()) {
-      out << " " << name;
-      if (name_printed) {
-        *name_printed = true;
-      }
-    }
-    for (uint32_t size : sizes) {
-      out << "[" << size << "]";
-    }
-    return true;
-  }
-  if (type->Is<sem::Bool>()) {
-    out << "bool";
-    return true;
-  }
-  if (type->Is<sem::F32>()) {
-    out << "float";
-    return true;
-  }
-  if (type->Is<sem::I32>()) {
-    out << "int";
-    return true;
-  }
-  if (auto* mat = type->As<sem::Matrix>()) {
-    if (!EmitType(out, mat->type(), "")) {
-      return false;
-    }
-    out << mat->columns() << "x" << mat->rows();
-    return true;
-  }
-  if (auto* ptr = type->As<sem::Pointer>()) {
-    if (ptr->Access() == ast::Access::kRead) {
-      out << "const ";
-    }
-    if (!EmitStorageClass(out, ptr->StorageClass())) {
-      return false;
-    }
-    out << " ";
-    if (ptr->StoreType()->Is<sem::Array>()) {
-      std::string inner = "(*" + name + ")";
-      if (!EmitType(out, ptr->StoreType(), inner)) {
+  return Switch(
+      type,
+      [&](const sem::Atomic* atomic) {
+        if (atomic->Type()->Is<sem::I32>()) {
+          out << "atomic_int";
+          return true;
+        }
+        if (atomic->Type()->Is<sem::U32>()) {
+          out << "atomic_uint";
+          return true;
+        }
+        TINT_ICE(Writer, diagnostics_)
+            << "unhandled atomic type " << atomic->Type()->type_name();
         return false;
-      }
-      if (name_printed) {
-        *name_printed = true;
-      }
-    } else {
-      if (!EmitType(out, ptr->StoreType(), "")) {
+      },
+      [&](const sem::Array* ary) {
+        const sem::Type* base_type = ary;
+        std::vector<uint32_t> sizes;
+        while (auto* arr = base_type->As<sem::Array>()) {
+          if (arr->IsRuntimeSized()) {
+            sizes.push_back(1);
+          } else {
+            sizes.push_back(arr->Count());
+          }
+          base_type = arr->ElemType();
+        }
+        if (!EmitType(out, base_type, "")) {
+          return false;
+        }
+        if (!name.empty()) {
+          out << " " << name;
+          if (name_printed) {
+            *name_printed = true;
+          }
+        }
+        for (uint32_t size : sizes) {
+          out << "[" << size << "]";
+        }
+        return true;
+      },
+      [&](const sem::Bool*) {
+        out << "bool";
+        return true;
+      },
+      [&](const sem::F32*) {
+        out << "float";
+        return true;
+      },
+      [&](const sem::I32*) {
+        out << "int";
+        return true;
+      },
+      [&](const sem::Matrix* mat) {
+        if (!EmitType(out, mat->type(), "")) {
+          return false;
+        }
+        out << mat->columns() << "x" << mat->rows();
+        return true;
+      },
+      [&](const sem::Pointer* ptr) {
+        if (ptr->Access() == ast::Access::kRead) {
+          out << "const ";
+        }
+        if (!EmitStorageClass(out, ptr->StorageClass())) {
+          return false;
+        }
+        out << " ";
+        if (ptr->StoreType()->Is<sem::Array>()) {
+          std::string inner = "(*" + name + ")";
+          if (!EmitType(out, ptr->StoreType(), inner)) {
+            return false;
+          }
+          if (name_printed) {
+            *name_printed = true;
+          }
+        } else {
+          if (!EmitType(out, ptr->StoreType(), "")) {
+            return false;
+          }
+          out << "* " << name;
+          if (name_printed) {
+            *name_printed = true;
+          }
+        }
+        return true;
+      },
+      [&](const sem::Sampler*) {
+        out << "sampler";
+        return true;
+      },
+      [&](const sem::Struct* str) {
+        // The struct type emits as just the name. The declaration would be
+        // emitted as part of emitting the declared types.
+        out << StructName(str);
+        return true;
+      },
+      [&](const sem::Texture* tex) {
+        if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
+          out << "depth";
+        } else {
+          out << "texture";
+        }
+        switch (tex->dim()) {
+          case ast::TextureDimension::k1d:
+            out << "1d";
+            break;
+          case ast::TextureDimension::k2d:
+            out << "2d";
+            break;
+          case ast::TextureDimension::k2dArray:
+            out << "2d_array";
+            break;
+          case ast::TextureDimension::k3d:
+            out << "3d";
+            break;
+          case ast::TextureDimension::kCube:
+            out << "cube";
+            break;
+          case ast::TextureDimension::kCubeArray:
+            out << "cube_array";
+            break;
+          default:
+            diagnostics_.add_error(diag::System::Writer,
+                                   "Invalid texture dimensions");
+            return false;
+        }
+        if (tex->IsAnyOf<sem::MultisampledTexture,
+                         sem::DepthMultisampledTexture>()) {
+          out << "_ms";
+        }
+        out << "<";
+        TINT_DEFER(out << ">");
+        return Switch(
+            tex,
+            [&](const sem::DepthTexture*) {
+              out << "float, access::sample";
+              return true;
+            },
+            [&](const sem::DepthMultisampledTexture*) {
+              out << "float, access::read";
+              return true;
+            },
+            [&](const sem::StorageTexture* storage) {
+              if (!EmitType(out, storage->type(), "")) {
+                return false;
+              }
+              std::string access_str;
+              if (storage->access() == ast::Access::kRead) {
+                out << ", access::read";
+              } else if (storage->access() == ast::Access::kWrite) {
+                out << ", access::write";
+              } else {
+                diagnostics_.add_error(
+                    diag::System::Writer,
+                    "Invalid access control for storage texture");
+                return false;
+              }
+              return true;
+            },
+            [&](const sem::MultisampledTexture* ms) {
+              if (!EmitType(out, ms->type(), "")) {
+                return false;
+              }
+              out << ", access::read";
+              return true;
+            },
+            [&](const sem::SampledTexture* sampled) {
+              if (!EmitType(out, sampled->type(), "")) {
+                return false;
+              }
+              out << ", access::sample";
+              return true;
+            },
+            [&](Default) {
+              diagnostics_.add_error(diag::System::Writer,
+                                     "invalid texture type");
+              return false;
+            });
+      },
+      [&](const sem::U32*) {
+        out << "uint";
+        return true;
+      },
+      [&](const sem::Vector* vec) {
+        if (!EmitType(out, vec->type(), "")) {
+          return false;
+        }
+        out << vec->Width();
+        return true;
+      },
+      [&](const sem::Void*) {
+        out << "void";
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown type in EmitType: " + type->type_name());
         return false;
-      }
-      out << "* " << name;
-      if (name_printed) {
-        *name_printed = true;
-      }
-    }
-    return true;
-  }
-  if (type->Is<sem::Sampler>()) {
-    out << "sampler";
-    return true;
-  }
-  if (auto* str = type->As<sem::Struct>()) {
-    // The struct type emits as just the name. The declaration would be emitted
-    // as part of emitting the declared types.
-    out << StructName(str);
-    return true;
-  }
-  if (auto* tex = type->As<sem::Texture>()) {
-    if (tex->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
-      out << "depth";
-    } else {
-      out << "texture";
-    }
-    switch (tex->dim()) {
-      case ast::TextureDimension::k1d:
-        out << "1d";
-        break;
-      case ast::TextureDimension::k2d:
-        out << "2d";
-        break;
-      case ast::TextureDimension::k2dArray:
-        out << "2d_array";
-        break;
-      case ast::TextureDimension::k3d:
-        out << "3d";
-        break;
-      case ast::TextureDimension::kCube:
-        out << "cube";
-        break;
-      case ast::TextureDimension::kCubeArray:
-        out << "cube_array";
-        break;
-      default:
-        diagnostics_.add_error(diag::System::Writer,
-                               "Invalid texture dimensions");
-        return false;
-    }
-    if (tex->IsAnyOf<sem::MultisampledTexture,
-                     sem::DepthMultisampledTexture>()) {
-      out << "_ms";
-    }
-    out << "<";
-    if (tex->Is<sem::DepthTexture>()) {
-      out << "float, access::sample";
-    } else if (tex->Is<sem::DepthMultisampledTexture>()) {
-      out << "float, access::read";
-    } else if (auto* storage = tex->As<sem::StorageTexture>()) {
-      if (!EmitType(out, storage->type(), "")) {
-        return false;
-      }
-      std::string access_str;
-      if (storage->access() == ast::Access::kRead) {
-        out << ", access::read";
-      } else if (storage->access() == ast::Access::kWrite) {
-        out << ", access::write";
-      } else {
-        diagnostics_.add_error(diag::System::Writer,
-                               "Invalid access control for storage texture");
-        return false;
-      }
-    } else if (auto* ms = tex->As<sem::MultisampledTexture>()) {
-      if (!EmitType(out, ms->type(), "")) {
-        return false;
-      }
-      out << ", access::read";
-    } else if (auto* sampled = tex->As<sem::SampledTexture>()) {
-      if (!EmitType(out, sampled->type(), "")) {
-        return false;
-      }
-      out << ", access::sample";
-    } else {
-      diagnostics_.add_error(diag::System::Writer, "invalid texture type");
-      return false;
-    }
-    out << ">";
-    return true;
-  }
-  if (type->Is<sem::U32>()) {
-    out << "uint";
-    return true;
-  }
-  if (auto* vec = type->As<sem::Vector>()) {
-    if (!EmitType(out, vec->type(), "")) {
-      return false;
-    }
-    out << vec->Width();
-    return true;
-  }
-  if (type->Is<sem::Void>()) {
-    out << "void";
-    return true;
-  }
-  diagnostics_.add_error(diag::System::Writer,
-                         "unknown type in EmitType: " + type->type_name());
-  return false;
+      });
 bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
@@ -2542,55 +2576,72 @@
     // Emit attributes
     if (auto* decl = mem->Declaration()) {
       for (auto* attr : decl->attributes) {
-        if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
-          auto name = builtin_to_attribute(builtin->builtin);
-          if (name.empty()) {
-            diagnostics_.add_error(diag::System::Writer, "unknown builtin");
-            return false;
-          }
-          out << " [[" << name << "]]";
-        } else if (auto* loc = attr->As<ast::LocationAttribute>()) {
-          auto& pipeline_stage_uses = str->PipelineStageUses();
-          if (pipeline_stage_uses.size() != 1) {
-            TINT_ICE(Writer, diagnostics_)
-                << "invalid entry point IO struct uses";
-          }
+        // Emit the MSL attribute text for each member attribute; every
+        // handler returns false on error so the caller can bail out.
+        bool ok = Switch(
+            attr,
+            [&](const ast::BuiltinAttribute* builtin) {
+              auto name = builtin_to_attribute(builtin->builtin);
+              if (name.empty()) {
+                diagnostics_.add_error(diag::System::Writer, "unknown builtin");
+                return false;
+              }
+              out << " [[" << name << "]]";
+              return true;
+            },
+            [&](const ast::LocationAttribute* loc) {
+              auto& pipeline_stage_uses = str->PipelineStageUses();
+              if (pipeline_stage_uses.size() != 1) {
+                TINT_ICE(Writer, diagnostics_)
+                    << "invalid entry point IO struct uses";
+                return false;
+              }
-          if (pipeline_stage_uses.count(
-                  sem::PipelineStageUsage::kVertexInput)) {
-            out << " [[attribute(" + std::to_string(loc->value) + ")]]";
-          } else if (pipeline_stage_uses.count(
-                         sem::PipelineStageUsage::kVertexOutput)) {
-            out << " [[user(locn" + std::to_string(loc->value) + ")]]";
-          } else if (pipeline_stage_uses.count(
-                         sem::PipelineStageUsage::kFragmentInput)) {
-            out << " [[user(locn" + std::to_string(loc->value) + ")]]";
-          } else if (pipeline_stage_uses.count(
-                         sem::PipelineStageUsage::kFragmentOutput)) {
-            out << " [[color(" + std::to_string(loc->value) + ")]]";
-          } else {
-            TINT_ICE(Writer, diagnostics_)
-                << "invalid use of location attribute";
-          }
-        } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
-          auto name = interpolation_to_attribute(interpolate->type,
-                                                 interpolate->sampling);
-          if (name.empty()) {
-            diagnostics_.add_error(diag::System::Writer,
-                                   "unknown interpolation attribute");
-            return false;
-          }
-          out << " [[" << name << "]]";
-        } else if (attr->Is<ast::InvariantAttribute>()) {
-          if (invariant_define_name_.empty()) {
-            invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
-          }
-          out << " " << invariant_define_name_;
-        } else if (!attr->IsAnyOf<ast::StructMemberOffsetAttribute,
-                                  ast::StructMemberAlignAttribute,
-                                  ast::StructMemberSizeAttribute>()) {
-          TINT_ICE(Writer, diagnostics_)
-              << "unhandled struct member attribute: " << attr->Name();
+              if (pipeline_stage_uses.count(
+                      sem::PipelineStageUsage::kVertexInput)) {
+                out << " [[attribute(" + std::to_string(loc->value) + ")]]";
+              } else if (pipeline_stage_uses.count(
+                             sem::PipelineStageUsage::kVertexOutput)) {
+                out << " [[user(locn" + std::to_string(loc->value) + ")]]";
+              } else if (pipeline_stage_uses.count(
+                             sem::PipelineStageUsage::kFragmentInput)) {
+                out << " [[user(locn" + std::to_string(loc->value) + ")]]";
+              } else if (pipeline_stage_uses.count(
+                             sem::PipelineStageUsage::kFragmentOutput)) {
+                out << " [[color(" + std::to_string(loc->value) + ")]]";
+              } else {
+                TINT_ICE(Writer, diagnostics_)
+                    << "invalid use of location attribute";
+                return false;
+              }
+              return true;
+            },
+            [&](const ast::InterpolateAttribute* interpolate) {
+              auto name = interpolation_to_attribute(interpolate->type,
+                                                     interpolate->sampling);
+              if (name.empty()) {
+                diagnostics_.add_error(diag::System::Writer,
+                                       "unknown interpolation attribute");
+                return false;
+              }
+              out << " [[" << name << "]]";
+              return true;
+            },
+            [&](const ast::InvariantAttribute*) {
+              if (invariant_define_name_.empty()) {
+                invariant_define_name_ = UniqueIdentifier("TINT_INVARIANT");
+              }
+              out << " " << invariant_define_name_;
+              return true;
+            },
+            // Offset/align/size attributes write nothing here; they are
+            // accepted without error.
+            [&](const ast::StructMemberOffsetAttribute*) { return true; },
+            [&](const ast::StructMemberAlignAttribute*) { return true; },
+            [&](const ast::StructMemberSizeAttribute*) { return true; },
+            [&](Default) {
+              TINT_ICE(Writer, diagnostics_)
+                  << "unhandled struct member attribute: " << attr->Name();
+              return false;
+            });
+        if (!ok) {
+          return false;
@@ -2796,77 +2847,96 @@
 GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(
     const sem::Type* ty) {
-  if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
-    //
-    // 2.1 Scalar Data Types
-    return {4, 4};
-  }
+  return Switch(  // section numbers below presumably refer to the MSL spec
+      ty,
-  if (auto* vec = ty->As<sem::Vector>()) {
-    auto num_els = vec->Width();
-    auto* el_ty = vec->type();
-    if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
-      // Use a packed_vec type for 3-element vectors only.
-      if (num_els == 3) {
+      //
+      // 2.1 Scalar Data Types
+      [&](const sem::U32*) {  // size 4, align 4
+        return SizeAndAlign{4, 4};
+      },
+      [&](const sem::I32*) {  // size 4, align 4
+        return SizeAndAlign{4, 4};
+      },
+      [&](const sem::F32*) {  // size 4, align 4
+        return SizeAndAlign{4, 4};
+      },
+      [&](const sem::Vector* vec) {  // u32/i32/f32 element vectors only
+        auto num_els = vec->Width();
+        auto* el_ty = vec->type();
+        if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
+          // Use a packed_vec type for 3-element vectors only.
+          if (num_els == 3) {
+            //
+            // 2.2.3 Packed Vector Types
+            return SizeAndAlign{num_els * 4, 4};
+          } else {
+            //
+            // 2.2 Vector Data Types
+            return SizeAndAlign{num_els * 4, num_els * 4};
+          }
+        }
+        TINT_UNREACHABLE(Writer, diagnostics_)  // other element types
+            << "Unhandled vector element type " << el_ty->TypeInfo().name;
+        return SizeAndAlign{};
+      },
+      [&](const sem::Matrix* mat) {  // 2x2..4x4 of u32/i32/f32 only
-        // 2.2.3 Packed Vector Types
-        return SizeAndAlign{num_els * 4, 4};
-      } else {
-        //
-        // 2.2 Vector Data Types
-        return SizeAndAlign{num_els * 4, num_els * 4};
-      }
-    }
-  }
+        // 2.3 Matrix Data Types
+        auto cols = mat->columns();
+        auto rows = mat->rows();
+        auto* el_ty = mat->type();
+        if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
+          static constexpr SizeAndAlign table[] = {
+              /* float2x2 */ {16, 8},
+              /* float2x3 */ {32, 16},
+              /* float2x4 */ {32, 16},
+              /* float3x2 */ {24, 8},
+              /* float3x3 */ {48, 16},
+              /* float3x4 */ {48, 16},
+              /* float4x2 */ {32, 8},
+              /* float4x3 */ {64, 16},
+              /* float4x4 */ {64, 16},
+          };
+          if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
+            return table[(3 * (cols - 2)) + (rows - 2)];  // 3 rows per col
+          }
+        }
-  if (auto* mat = ty->As<sem::Matrix>()) {
-    //
-    // 2.3 Matrix Data Types
-    auto cols = mat->columns();
-    auto rows = mat->rows();
-    auto* el_ty = mat->type();
-    if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
-      static constexpr SizeAndAlign table[] = {
-          /* float2x2 */ {16, 8},
-          /* float2x3 */ {32, 16},
-          /* float2x4 */ {32, 16},
-          /* float3x2 */ {24, 8},
-          /* float3x3 */ {48, 16},
-          /* float3x4 */ {48, 16},
-          /* float4x2 */ {32, 8},
-          /* float4x3 */ {64, 16},
-          /* float4x4 */ {64, 16},
-      };
-      if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
-        return table[(3 * (cols - 2)) + (rows - 2)];
-      }
-    }
-  }
+        TINT_UNREACHABLE(Writer, diagnostics_)  // other shapes/element types
+            << "Unhandled matrix element type " << el_ty->TypeInfo().name;
+        return SizeAndAlign{};
+      },
-  if (auto* arr = ty->As<sem::Array>()) {
-    if (!arr->IsStrideImplicit()) {
-      TINT_ICE(Writer, diagnostics_)
-          << "arrays with explicit strides should have "
-             "removed with the PadArrayElements transform";
-      return {};
-    }
-    auto num_els = std::max<uint32_t>(arr->Count(), 1);
-    return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
-  }
+      [&](const sem::Array* arr) {  // implicit-stride arrays only
+        if (!arr->IsStrideImplicit()) {
+          TINT_ICE(Writer, diagnostics_)
+              << "arrays with explicit strides should have "
+                 "removed with the PadArrayElements transform";
+          return SizeAndAlign{};
+        }
+        auto num_els = std::max<uint32_t>(arr->Count(), 1);  // 0 counts as 1
+        return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
+      },
-  if (auto* str = ty->As<sem::Struct>()) {
-    // TODO( There's an assumption here that MSL's default
-    // structure size and alignment matches WGSL's. We need to confirm this.
-    return SizeAndAlign{str->Size(), str->Align()};
-  }
+      [&](const sem::Struct* str) {
+        // TODO( There's an assumption here that MSL's
+        // default structure size and alignment matches WGSL's. We need to
+        // confirm this.
+        return SizeAndAlign{str->Size(), str->Align()};
+      },
-  if (auto* atomic = ty->As<sem::Atomic>()) {
-    return MslPackedTypeSizeAndAlign(atomic->Type());
-  }
+      [&](const sem::Atomic* atomic) {  // same as the wrapped type
+        return MslPackedTypeSizeAndAlign(atomic->Type());
+      },
-  TINT_UNREACHABLE(Writer, diagnostics_)
-      << "Unhandled type " << ty->TypeInfo().name;
-  return {};
+      [&](Default) {  // any type not listed above is unexpected here
+        TINT_UNREACHABLE(Writer, diagnostics_)
+            << "Unhandled type " << ty->TypeInfo().name;
+        return SizeAndAlign{};
+      });
 template <typename F>
diff --git a/src/writer/spirv/ b/src/writer/spirv/
index 19707c0..932ae8c 100644
--- a/src/writer/spirv/
+++ b/src/writer/spirv/
@@ -560,33 +560,37 @@
 uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
-  if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
-    return GenerateAccessorExpression(a);
-  }
-  if (auto* b = expr->As<ast::BinaryExpression>()) {
-    return GenerateBinaryExpression(b);
-  }
-  if (auto* b = expr->As<ast::BitcastExpression>()) {
-    return GenerateBitcastExpression(b);
-  }
-  if (auto* c = expr->As<ast::CallExpression>()) {
-    return GenerateCallExpression(c);
-  }
-  if (auto* i = expr->As<ast::IdentifierExpression>()) {
-    return GenerateIdentifierExpression(i);
-  }
-  if (auto* l = expr->As<ast::LiteralExpression>()) {
-    return GenerateLiteralIfNeeded(nullptr, l);
-  }
-  if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
-    return GenerateAccessorExpression(m);
-  }
-  if (auto* u = expr->As<ast::UnaryOpExpression>()) {
-    return GenerateUnaryOpExpression(u);
-  }
-  error_ = "unknown expression type: " + std::string(expr->TypeInfo().name);
-  return 0;
+  // Dispatch on the concrete expression node. Node types with no handler
+  // record an error message and yield 0.
+  return Switch(
+      expr,
+      [&](const ast::IndexAccessorExpression* index_expr) {
+        return GenerateAccessorExpression(index_expr);
+      },
+      [&](const ast::BinaryExpression* binary_expr) {
+        return GenerateBinaryExpression(binary_expr);
+      },
+      [&](const ast::BitcastExpression* bitcast_expr) {
+        return GenerateBitcastExpression(bitcast_expr);
+      },
+      [&](const ast::CallExpression* call_expr) {
+        return GenerateCallExpression(call_expr);
+      },
+      [&](const ast::IdentifierExpression* ident_expr) {
+        return GenerateIdentifierExpression(ident_expr);
+      },
+      [&](const ast::LiteralExpression* literal_expr) {
+        return GenerateLiteralIfNeeded(nullptr, literal_expr);
+      },
+      [&](const ast::MemberAccessorExpression* member_expr) {
+        return GenerateAccessorExpression(member_expr);
+      },
+      [&](const ast::UnaryOpExpression* unary_expr) {
+        return GenerateUnaryOpExpression(unary_expr);
+      },
+      [&](Default) -> uint32_t {
+        const std::string type_name(expr->TypeInfo().name);
+        error_ = "unknown expression type: " + type_name;
+        return 0;
+      });
 bool Builder::GenerateFunction(const ast::Function* func_ast) {
@@ -861,33 +865,56 @@
   push_type(spv::Op::OpVariable, std::move(ops));
   for (auto* attr : var->attributes) {
-    if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
-      push_annot(spv::Op::OpDecorate,
-                 {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
-                  Operand::Int(
-                      ConvertBuiltin(builtin->builtin, sem->StorageClass()))});
-    } else if (auto* location = attr->As<ast::LocationAttribute>()) {
-      push_annot(spv::Op::OpDecorate,
-                 {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
-                  Operand::Int(location->value)});
-    } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
-      AddInterpolationDecorations(var_id, interpolate->type,
-                                  interpolate->sampling);
-    } else if (attr->Is<ast::InvariantAttribute>()) {
-      push_annot(spv::Op::OpDecorate,
-                 {Operand::Int(var_id), Operand::Int(SpvDecorationInvariant)});
-    } else if (auto* binding = attr->As<ast::BindingAttribute>()) {
-      push_annot(spv::Op::OpDecorate,
-                 {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
-                  Operand::Int(binding->value)});
-    } else if (auto* group = attr->As<ast::GroupAttribute>()) {
-      push_annot(spv::Op::OpDecorate, {Operand::Int(var_id),
-                                       Operand::Int(SpvDecorationDescriptorSet),
-                                       Operand::Int(group->value)});
-    } else if (attr->Is<ast::OverrideAttribute>()) {
-      // Spec constants are handled elsewhere
-    } else if (!attr->Is<ast::InternalAttribute>()) {
-      error_ = "unknown attribute";
+    // Translate the attribute into an OpDecorate annotation on the variable.
+    // Attributes that have no decoration here are accepted without output;
+    // anything unrecognized is an error.
+    bool ok = Switch(
+        attr,
+        [&](const ast::BuiltinAttribute* builtin_attr) {
+          push_annot(
+              spv::Op::OpDecorate,
+              {Operand::Int(var_id), Operand::Int(SpvDecorationBuiltIn),
+               Operand::Int(ConvertBuiltin(builtin_attr->builtin,
+                                           sem->StorageClass()))});
+          return true;
+        },
+        [&](const ast::LocationAttribute* location_attr) {
+          push_annot(
+              spv::Op::OpDecorate,
+              {Operand::Int(var_id), Operand::Int(SpvDecorationLocation),
+               Operand::Int(location_attr->value)});
+          return true;
+        },
+        [&](const ast::InterpolateAttribute* interpolate_attr) {
+          AddInterpolationDecorations(var_id, interpolate_attr->type,
+                                      interpolate_attr->sampling);
+          return true;
+        },
+        [&](const ast::InvariantAttribute*) {
+          push_annot(spv::Op::OpDecorate,
+                     {Operand::Int(var_id),
+                      Operand::Int(SpvDecorationInvariant)});
+          return true;
+        },
+        [&](const ast::BindingAttribute* binding_attr) {
+          push_annot(
+              spv::Op::OpDecorate,
+              {Operand::Int(var_id), Operand::Int(SpvDecorationBinding),
+               Operand::Int(binding_attr->value)});
+          return true;
+        },
+        [&](const ast::GroupAttribute* group_attr) {
+          push_annot(
+              spv::Op::OpDecorate,
+              {Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
+               Operand::Int(group_attr->value)});
+          return true;
+        },
+        [&](const ast::OverrideAttribute*) {
+          // Spec constants are handled elsewhere.
+          return true;
+        },
+        [&](const ast::InternalAttribute*) {
+          // Internal attributes produce no decoration.
+          return true;
+        },
+        [&](Default) {
+          error_ = "unknown attribute";
+          return false;
+        });
+    if (!ok) {
       return false;
@@ -1123,19 +1150,21 @@
   // promoted to storage with the VarForDynamicIndex transform.
   for (auto* accessor : accessors) {
-    if (auto* array = accessor->As<ast::IndexAccessorExpression>()) {
-      if (!GenerateIndexAccessor(array, &info)) {
-        return 0;
-      }
-    } else if (auto* member = accessor->As<ast::MemberAccessorExpression>()) {
-      if (!GenerateMemberAccessor(member, &info)) {
-        return 0;
-      }
-    } else {
-      error_ =
-          "invalid accessor in list: " + std::string(accessor->TypeInfo().name);
-      return 0;
+    // Emit one link of the accessor chain into `info`. An accessor of any
+    // other kind records an error.
+    bool ok = Switch(
+        accessor,
+        [&](const ast::IndexAccessorExpression* array) {
+          return GenerateIndexAccessor(array, &info);
+        },
+        [&](const ast::MemberAccessorExpression* member) {
+          return GenerateMemberAccessor(member, &info);
+        },
+        [&](Default) {
+          error_ = "invalid accessor in list: " +
+                   std::string(accessor->TypeInfo().name);
+          return false;
+        });
+    if (!ok) {
+      // This function yields a result id, and 0 is the failure id (as in the
+      // code this replaces). Return it explicitly rather than relying on an
+      // implicit bool -> uint32_t conversion of `false`.
+      return 0;
@@ -1653,21 +1682,28 @@
     constant.constant_id = global->ConstantId();
-  if (auto* l = lit->As<ast::BoolLiteralExpression>()) {
-    constant.kind = ScalarConstant::Kind::kBool;
-    constant.value.b = l->value;
-  } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
-    constant.kind = ScalarConstant::Kind::kI32;
-    constant.value.i32 = sl->value;
-  } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
-    constant.kind = ScalarConstant::Kind::kU32;
-    constant.value.u32 = ul->value;
-  } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
-    constant.kind = ScalarConstant::Kind::kF32;
-    constant.value.f32 = fl->value;
-  } else {
-    error_ = "unknown literal type";
-    return 0;
+  // Each case fills in `constant` and reports whether the literal kind was
+  // recognized. Success is taken from the Switch result rather than from
+  // error_, which is shared Builder state that other code paths also write;
+  // a stale message there must not make this literal fail.
+  bool ok = Switch(
+      lit,
+      [&](const ast::BoolLiteralExpression* l) {
+        constant.kind = ScalarConstant::Kind::kBool;
+        constant.value.b = l->value;
+        return true;
+      },
+      [&](const ast::SintLiteralExpression* sl) {
+        constant.kind = ScalarConstant::Kind::kI32;
+        constant.value.i32 = sl->value;
+        return true;
+      },
+      [&](const ast::UintLiteralExpression* ul) {
+        constant.kind = ScalarConstant::Kind::kU32;
+        constant.value.u32 = ul->value;
+        return true;
+      },
+      [&](const ast::FloatLiteralExpression* fl) {
+        constant.kind = ScalarConstant::Kind::kF32;
+        constant.value.f32 = fl->value;
+        return true;
+      },
+      [&](Default) {
+        error_ = "unknown literal type";
+        return false;
+      });
+  if (!ok) {
+    // 0 is the failure result id for this uint32_t-returning function.
+    return 0;
   return GenerateConstantIfNeeded(constant);
@@ -2209,19 +2245,25 @@
 uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
   auto* call = builder_.Sem().Get(expr);
   auto* target = call->Target();
-  if (auto* func = target->As<sem::Function>()) {
-    return GenerateFunctionCall(call, func);
-  }
-  if (auto* builtin = target->As<sem::Builtin>()) {
-    return GenerateBuiltinCall(call, builtin);
-  }
-  if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
-    return GenerateTypeConstructorOrConversion(call, nullptr);
-  }
-  TINT_ICE(Writer, builder_.Diagnostics())
-      << "unhandled call target: " << target->TypeInfo().name;
-  return false;
+  // Type conversions and type constructors are lowered by the same routine.
+  auto generate_ctor_or_conv = [&] {
+    return GenerateTypeConstructorOrConversion(call, nullptr);
+  };
+  // Dispatch on the kind of call target; unknown targets raise an ICE and
+  // yield 0.
+  return Switch(
+      target,
+      [&](const sem::Function* callee_fn) {
+        return GenerateFunctionCall(call, callee_fn);
+      },
+      [&](const sem::Builtin* callee_builtin) {
+        return GenerateBuiltinCall(call, callee_builtin);
+      },
+      [&](const sem::TypeConversion*) { return generate_ctor_or_conv(); },
+      [&](const sem::TypeConstructor*) { return generate_ctor_or_conv(); },
+      [&](Default) -> uint32_t {
+        TINT_ICE(Writer, builder_.Diagnostics())
+            << "unhandled call target: " << target->TypeInfo().name;
+        return 0;
+      });
 uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
@@ -3790,46 +3832,49 @@
 bool Builder::GenerateStatement(const ast::Statement* stmt) {
-  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
-    return GenerateAssignStatement(a);
-  }
-  if (auto* b = stmt->As<ast::BlockStatement>()) {
-    return GenerateBlockStatement(b);
-  }
-  if (auto* b = stmt->As<ast::BreakStatement>()) {
-    return GenerateBreakStatement(b);
-  }
-  if (auto* c = stmt->As<ast::CallStatement>()) {
-    return GenerateCallExpression(c->expr) != 0;
-  }
-  if (auto* c = stmt->As<ast::ContinueStatement>()) {
-    return GenerateContinueStatement(c);
-  }
-  if (auto* d = stmt->As<ast::DiscardStatement>()) {
-    return GenerateDiscardStatement(d);
-  }
-  if (stmt->Is<ast::FallthroughStatement>()) {
-    // Do nothing here, the fallthrough gets handled by the switch code.
-    return true;
-  }
-  if (auto* i = stmt->As<ast::IfStatement>()) {
-    return GenerateIfStatement(i);
-  }
-  if (auto* l = stmt->As<ast::LoopStatement>()) {
-    return GenerateLoopStatement(l);
-  }
-  if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return GenerateReturnStatement(r);
-  }
-  if (auto* s = stmt->As<ast::SwitchStatement>()) {
-    return GenerateSwitchStatement(s);
-  }
-  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-    return GenerateVariableDeclStatement(v);
-  }
-  error_ = "Unknown statement: " + std::string(stmt->TypeInfo().name);
-  return false;
+  // Dispatch on the concrete statement type. Statements with no handler
+  // record an error message and fail.
+  return Switch(
+      stmt,
+      [&](const ast::AssignmentStatement* assign_stmt) {
+        return GenerateAssignStatement(assign_stmt);
+      },
+      [&](const ast::BlockStatement* block_stmt) {
+        return GenerateBlockStatement(block_stmt);
+      },
+      [&](const ast::BreakStatement* break_stmt) {
+        return GenerateBreakStatement(break_stmt);
+      },
+      [&](const ast::CallStatement* call_stmt) {
+        return GenerateCallExpression(call_stmt->expr) != 0u;
+      },
+      [&](const ast::ContinueStatement* continue_stmt) {
+        return GenerateContinueStatement(continue_stmt);
+      },
+      [&](const ast::DiscardStatement* discard_stmt) {
+        return GenerateDiscardStatement(discard_stmt);
+      },
+      [&](const ast::FallthroughStatement*) {
+        // Nothing to emit here: the switch code handles the fallthrough.
+        return true;
+      },
+      [&](const ast::IfStatement* if_stmt) {
+        return GenerateIfStatement(if_stmt);
+      },
+      [&](const ast::LoopStatement* loop_stmt) {
+        return GenerateLoopStatement(loop_stmt);
+      },
+      [&](const ast::ReturnStatement* return_stmt) {
+        return GenerateReturnStatement(return_stmt);
+      },
+      [&](const ast::SwitchStatement* switch_stmt) {
+        return GenerateSwitchStatement(switch_stmt);
+      },
+      [&](const ast::VariableDeclStatement* decl_stmt) {
+        return GenerateVariableDeclStatement(decl_stmt);
+      },
+      [&](Default) {
+        const std::string type_name(stmt->TypeInfo().name);
+        error_ = "Unknown statement: " + type_name;
+        return false;
+      });
 bool Builder::GenerateVariableDeclStatement(
@@ -3872,78 +3917,91 @@
   return utils::GetOrCreate(type_name_to_id_, type_name, [&]() -> uint32_t {
     auto result = result_op();
     auto id = result.to_i();
-    if (auto* arr = type->As<sem::Array>()) {
-      if (!GenerateArrayType(arr, result)) {
-        return 0;
-      }
-    } else if (type->Is<sem::Bool>()) {
-      push_type(spv::Op::OpTypeBool, {result});
-    } else if (type->Is<sem::F32>()) {
-      push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
-    } else if (type->Is<sem::I32>()) {
-      push_type(spv::Op::OpTypeInt,
-                {result, Operand::Int(32), Operand::Int(1)});
-    } else if (auto* mat = type->As<sem::Matrix>()) {
-      if (!GenerateMatrixType(mat, result)) {
-        return 0;
-      }
-    } else if (auto* ptr = type->As<sem::Pointer>()) {
-      if (!GeneratePointerType(ptr, result)) {
-        return 0;
-      }
-    } else if (auto* ref = type->As<sem::Reference>()) {
-      if (!GenerateReferenceType(ref, result)) {
-        return 0;
-      }
-    } else if (auto* str = type->As<sem::Struct>()) {
-      if (!GenerateStructType(str, result)) {
-        return 0;
-      }
-    } else if (type->Is<sem::U32>()) {
-      push_type(spv::Op::OpTypeInt,
-                {result, Operand::Int(32), Operand::Int(0)});
-    } else if (auto* vec = type->As<sem::Vector>()) {
-      if (!GenerateVectorType(vec, result)) {
-        return 0;
-      }
-    } else if (type->Is<sem::Void>()) {
-      push_type(spv::Op::OpTypeVoid, {result});
-    } else if (auto* tex = type->As<sem::Texture>()) {
-      if (!GenerateTextureType(tex, result)) {
-        return 0;
-      }
+    bool ok = Switch(  // false if the type could not be emitted
+        type,
+        [&](const sem::Array* arr) {  //
+          return GenerateArrayType(arr, result);
+        },
+        [&](const sem::Bool*) {
+          push_type(spv::Op::OpTypeBool, {result});
+          return true;
+        },
+        [&](const sem::F32*) {
+          push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)});
+          return true;
+        },
+        [&](const sem::I32*) {  // 32-bit, signedness operand 1
+          push_type(spv::Op::OpTypeInt,
+                    {result, Operand::Int(32), Operand::Int(1)});
+          return true;
+        },
+        [&](const sem::Matrix* mat) {  //
+          return GenerateMatrixType(mat, result);
+        },
+        [&](const sem::Pointer* ptr) {  //
+          return GeneratePointerType(ptr, result);
+        },
+        [&](const sem::Reference* ref) {  //
+          return GenerateReferenceType(ref, result);
+        },
+        [&](const sem::Struct* str) {  //
+          return GenerateStructType(str, result);
+        },
+        [&](const sem::U32*) {  // 32-bit, signedness operand 0
+          push_type(spv::Op::OpTypeInt,
+                    {result, Operand::Int(32), Operand::Int(0)});
+          return true;
+        },
+        [&](const sem::Vector* vec) {  //
+          return GenerateVectorType(vec, result);
+        },
+        [&](const sem::Void*) {
+          push_type(spv::Op::OpTypeVoid, {result});
+          return true;
+        },
+        [&](const sem::StorageTexture* tex) {  // keep before sem::Texture
+          if (!GenerateTextureType(tex, result)) {
+            return false;
+          }
-      if (auto* st = tex->As<sem::StorageTexture>()) {
-        // Register all three access types of StorageTexture names. In SPIR-V,
-        // we must output a single type, while the variable is annotated with
-        // the access type. Doing this ensures we de-dupe.
-        type_name_to_id_[builder_
-                             .create<sem::StorageTexture>(
-                                 st->dim(), st->texel_format(),
-                                 ast::Access::kRead, st->type())
-                             ->type_name()] = id;
-        type_name_to_id_[builder_
-                             .create<sem::StorageTexture>(
-                                 st->dim(), st->texel_format(),
-                                 ast::Access::kWrite, st->type())
-                             ->type_name()] = id;
-        type_name_to_id_[builder_
-                             .create<sem::StorageTexture>(
-                                 st->dim(), st->texel_format(),
-                                 ast::Access::kReadWrite, st->type())
-                             ->type_name()] = id;
-      }
+          // Register all three access types of StorageTexture names. In
+          // SPIR-V, we must output a single type, while the variable is
+          // annotated with the access type. Doing this ensures we de-dupe.
+          type_name_to_id_[builder_
+                               .create<sem::StorageTexture>(
+                                   tex->dim(), tex->texel_format(),
+                                   ast::Access::kRead, tex->type())
+                               ->type_name()] = id;
+          type_name_to_id_[builder_
+                               .create<sem::StorageTexture>(
+                                   tex->dim(), tex->texel_format(),
+                                   ast::Access::kWrite, tex->type())
+                               ->type_name()] = id;
+          type_name_to_id_[builder_
+                               .create<sem::StorageTexture>(
+                                   tex->dim(), tex->texel_format(),
+                                   ast::Access::kReadWrite, tex->type())
+                               ->type_name()] = id;
+          return true;
+        },
+        [&](const sem::Texture* tex) {
+          return GenerateTextureType(tex, result);
+        },
+        [&](const sem::Sampler*) {
+          push_type(spv::Op::OpTypeSampler, {result});
-    } else if (type->Is<sem::Sampler>()) {
-      push_type(spv::Op::OpTypeSampler, {result});
+          // Register both of the sampler type names. In SPIR-V they're the same
+          // sampler type, so we need to match that when we do the dedup check.
+          type_name_to_id_["__sampler_sampler"] = id;
+          type_name_to_id_["__sampler_comparison"] = id;
+          return true;
+        },
+        [&](Default) {  // any sem type not listed above
+          error_ = "unable to convert type: " + type->type_name();
+          return false;
+        });
-      // Register both of the sampler type names. In SPIR-V they're the same
-      // sampler type, so we need to match that when we do the dedup check.
-      type_name_to_id_["__sampler_sampler"] = id;
-      type_name_to_id_["__sampler_comparison"] = id;
-    } else {
-      error_ = "unable to convert type: " + type->type_name();
+    if (!ok) {
       return 0;
@@ -3995,22 +4053,31 @@
   if (dim == ast::TextureDimension::kCubeArray) {
-    if (texture->Is<sem::SampledTexture>() ||
-        texture->Is<sem::DepthTexture>()) {
+    if (texture->IsAnyOf<sem::SampledTexture, sem::DepthTexture>()) {  // sampled/depth cube arrays
-  uint32_t type_id = 0u;
-  if (texture->IsAnyOf<sem::DepthTexture, sem::DepthMultisampledTexture>()) {
-    type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
-  } else if (auto* s = texture->As<sem::SampledTexture>()) {
-    type_id = GenerateTypeIfNeeded(s->type());
-  } else if (auto* ms = texture->As<sem::MultisampledTexture>()) {
-    type_id = GenerateTypeIfNeeded(ms->type());
-  } else if (auto* st = texture->As<sem::StorageTexture>()) {
-    type_id = GenerateTypeIfNeeded(st->type());
-  }
+  // Depth textures use an f32 element type; the other texture kinds use
+  // their declared element type. Any other texture yields 0, which the
+  // check that follows treats as failure.
+  auto generate_f32_type = [&] {
+    return GenerateTypeIfNeeded(builder_.create<sem::F32>());
+  };
+  uint32_t type_id = Switch(
+      texture,
+      [&](const sem::DepthTexture*) { return generate_f32_type(); },
+      [&](const sem::DepthMultisampledTexture*) {
+        return generate_f32_type();
+      },
+      [&](const sem::SampledTexture* sampled_tex) {
+        return GenerateTypeIfNeeded(sampled_tex->type());
+      },
+      [&](const sem::MultisampledTexture* ms_tex) {
+        return GenerateTypeIfNeeded(ms_tex->type());
+      },
+      [&](const sem::StorageTexture* storage_tex) {
+        return GenerateTypeIfNeeded(storage_tex->type());
+      },
+      [&](Default) -> uint32_t { return 0u; });
   if (type_id == 0u) {
     return false;
diff --git a/src/writer/wgsl/ b/src/writer/wgsl/
index 094e889..96a459a 100644
--- a/src/writer/wgsl/
+++ b/src/writer/wgsl/
@@ -68,23 +68,17 @@
 bool GeneratorImpl::Generate() {
   // Generate global declarations in the order they appear in the module.
   for (auto* decl : program_->AST().GlobalDeclarations()) {
-    if (auto* td = decl->As<ast::TypeDecl>()) {
-      if (!EmitTypeDecl(td)) {
-        return false;
-      }
-    } else if (auto* func = decl->As<ast::Function>()) {
-      if (!EmitFunction(func)) {
-        return false;
-      }
-    } else if (auto* var = decl->As<ast::Variable>()) {
-      if (!EmitVariable(line(), var)) {
-        return false;
-      }
-    } else {
-      TINT_UNREACHABLE(Writer, diagnostics_);
+    // Emit the declaration according to its kind; any other kind is a bug.
+    const bool emitted = Switch(
+        decl,
+        [&](const ast::TypeDecl* type_decl) {
+          return EmitTypeDecl(type_decl);
+        },
+        [&](const ast::Function* function) {
+          return EmitFunction(function);
+        },
+        [&](const ast::Variable* global_var) {
+          return EmitVariable(line(), global_var);
+        },
+        [&](Default) {
+          TINT_UNREACHABLE(Writer, diagnostics_);
+          return false;
+        });
+    if (!emitted) {
       return false;
     if (decl != program_->AST().GlobalDeclarations().back()) {
@@ -94,59 +88,64 @@
 bool GeneratorImpl::EmitTypeDecl(const ast::TypeDecl* ty) {
-  if (auto* alias = ty->As<ast::Alias>()) {
-    auto out = line();
-    out << "type " << program_->Symbols().NameFor(alias->name) << " = ";
-    if (!EmitType(out, alias->type)) {
-      return false;
-    }
-    out << ";";
-  } else if (auto* str = ty->As<ast::Struct>()) {
-    if (!EmitStructType(str)) {
-      return false;
-    }
-  } else {
-    diagnostics_.add_error(
-        diag::System::Writer,
-        "unknown declared type: " + std::string(ty->TypeInfo().name));
-    return false;
-  }
-  return true;
+  // Aliases are written inline as `type <name> = <type>;`; structs are
+  // delegated to EmitStructType. Anything else is reported as an error.
+  return Switch(
+      ty,
+      [&](const ast::Alias* alias_decl) {
+        auto out = line();
+        out << "type " << program_->Symbols().NameFor(alias_decl->name)
+            << " = ";
+        const bool emitted = EmitType(out, alias_decl->type);
+        if (emitted) {
+          out << ";";
+        }
+        return emitted;
+      },
+      [&](const ast::Struct* struct_decl) {
+        return EmitStructType(struct_decl);
+      },
+      [&](Default) {
+        const std::string type_name(ty->TypeInfo().name);
+        diagnostics_.add_error(diag::System::Writer,
+                               "unknown declared type: " + type_name);
+        return false;
+      });
 bool GeneratorImpl::EmitExpression(std::ostream& out,
                                    const ast::Expression* expr) {
-  if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
-    return EmitIndexAccessor(out, a);
-  }
-  if (auto* b = expr->As<ast::BinaryExpression>()) {
-    return EmitBinary(out, b);
-  }
-  if (auto* b = expr->As<ast::BitcastExpression>()) {
-    return EmitBitcast(out, b);
-  }
-  if (auto* c = expr->As<ast::CallExpression>()) {
-    return EmitCall(out, c);
-  }
-  if (auto* i = expr->As<ast::IdentifierExpression>()) {
-    return EmitIdentifier(out, i);
-  }
-  if (auto* l = expr->As<ast::LiteralExpression>()) {
-    return EmitLiteral(out, l);
-  }
-  if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
-    return EmitMemberAccessor(out, m);
-  }
-  if (expr->Is<ast::PhonyExpression>()) {
-    out << "_";
-    return true;
-  }
-  if (auto* u = expr->As<ast::UnaryOpExpression>()) {
-    return EmitUnaryOp(out, u);
-  }
-  diagnostics_.add_error(diag::System::Writer, "unknown expression type");
-  return false;
+  // Dispatch on the concrete expression node, writing WGSL text to `out`.
+  // Unknown node types are reported as an error.
+  return Switch(
+      expr,
+      [&](const ast::IndexAccessorExpression* index_expr) {
+        return EmitIndexAccessor(out, index_expr);
+      },
+      [&](const ast::BinaryExpression* binary_expr) {
+        return EmitBinary(out, binary_expr);
+      },
+      [&](const ast::BitcastExpression* bitcast_expr) {
+        return EmitBitcast(out, bitcast_expr);
+      },
+      [&](const ast::CallExpression* call_expr) {
+        return EmitCall(out, call_expr);
+      },
+      [&](const ast::IdentifierExpression* ident_expr) {
+        return EmitIdentifier(out, ident_expr);
+      },
+      [&](const ast::LiteralExpression* literal_expr) {
+        return EmitLiteral(out, literal_expr);
+      },
+      [&](const ast::MemberAccessorExpression* member_expr) {
+        return EmitMemberAccessor(out, member_expr);
+      },
+      [&](const ast::PhonyExpression*) {
+        // A phony expression is printed as a bare underscore.
+        out << "_";
+        return true;
+      },
+      [&](const ast::UnaryOpExpression* unary_expr) {
+        return EmitUnaryOp(out, unary_expr);
+      },
+      [&](Default) {
+        diagnostics_.add_error(diag::System::Writer, "unknown expression type");
+        return false;
+      });
 bool GeneratorImpl::EmitIndexAccessor(
@@ -250,19 +249,28 @@
 bool GeneratorImpl::EmitLiteral(std::ostream& out,
                                 const ast::LiteralExpression* lit) {
-  if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
-    out << (bl->value ? "true" : "false");
-  } else if (auto* fl = lit->As<ast::FloatLiteralExpression>()) {
-    out << FloatToBitPreservingString(fl->value);
-  } else if (auto* sl = lit->As<ast::SintLiteralExpression>()) {
-    out << sl->value;
-  } else if (auto* ul = lit->As<ast::UintLiteralExpression>()) {
-    out << ul->value << "u";
-  } else {
-    diagnostics_.add_error(diag::System::Writer, "unknown literal type");
-    return false;
-  }
-  return true;
+  // Write the literal's WGSL spelling to `out`; unknown literal kinds are
+  // reported as an error.
+  return Switch(
+      lit,
+      [&](const ast::BoolLiteralExpression* bool_lit) {
+        if (bool_lit->value) {
+          out << "true";
+        } else {
+          out << "false";
+        }
+        return true;
+      },
+      [&](const ast::FloatLiteralExpression* float_lit) {
+        out << FloatToBitPreservingString(float_lit->value);
+        return true;
+      },
+      [&](const ast::SintLiteralExpression* sint_lit) {
+        out << sint_lit->value;
+        return true;
+      },
+      [&](const ast::UintLiteralExpression* uint_lit) {
+        // Unsigned literals carry a `u` suffix.
+        out << uint_lit->value << "u";
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(diag::System::Writer, "unknown literal type");
+        return false;
+      });
 bool GeneratorImpl::EmitIdentifier(std::ostream& out,
@@ -366,155 +374,208 @@
 bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) {
-  if (auto* ary = ty->As<ast::Array>()) {
-    for (auto* attr : ary->attributes) {
-      if (auto* stride = attr->As<ast::StrideAttribute>()) {
-        out << "@stride(" << stride->stride << ") ";
-      }
-    }
+  return Switch(
+      ty,
+      [&](const ast::Array* ary) {
+        for (auto* attr : ary->attributes) {
+          if (auto* stride = attr->As<ast::StrideAttribute>()) {
+            out << "@stride(" << stride->stride << ") ";
+          }
+        }
-    out << "array<";
-    if (!EmitType(out, ary->type)) {
-      return false;
-    }
+        out << "array<";
+        if (!EmitType(out, ary->type)) {
+          return false;
+        }
-    if (!ary->IsRuntimeArray()) {
-      out << ", ";
-      if (!EmitExpression(out, ary->count)) {
-        return false;
-      }
-    }
+        if (!ary->IsRuntimeArray()) {
+          out << ", ";
+          if (!EmitExpression(out, ary->count)) {
+            return false;
+          }
+        }
-    out << ">";
-  } else if (ty->Is<ast::Bool>()) {
-    out << "bool";
-  } else if (ty->Is<ast::F32>()) {
-    out << "f32";
-  } else if (ty->Is<ast::I32>()) {
-    out << "i32";
-  } else if (auto* mat = ty->As<ast::Matrix>()) {
-    out << "mat" << mat->columns << "x" << mat->rows;
-    if (auto* el_ty = mat->type) {
-      out << "<";
-      if (!EmitType(out, el_ty)) {
-        return false;
-      }
-      out << ">";
-    }
-  } else if (auto* ptr = ty->As<ast::Pointer>()) {
-    out << "ptr<" << ptr->storage_class << ", ";
-    if (!EmitType(out, ptr->type)) {
-      return false;
-    }
-    if (ptr->access != ast::Access::kUndefined) {
-      out << ", ";
-      if (!EmitAccess(out, ptr->access)) {
-        return false;
-      }
-    }
-    out << ">";
-  } else if (auto* atomic = ty->As<ast::Atomic>()) {
-    out << "atomic<";
-    if (!EmitType(out, atomic->type)) {
-      return false;
-    }
-    out << ">";
-  } else if (auto* sampler = ty->As<ast::Sampler>()) {
-    out << "sampler";
+        out << ">";
+        return true;
+      },
+      [&](const ast::Bool*) {
+        out << "bool";
+        return true;
+      },
+      [&](const ast::F32*) {
+        out << "f32";
+        return true;
+      },
+      [&](const ast::I32*) {
+        out << "i32";
+        return true;
+      },
+      [&](const ast::Matrix* mat) {
+        out << "mat" << mat->columns << "x" << mat->rows;
+        if (auto* el_ty = mat->type) {
+          out << "<";
+          if (!EmitType(out, el_ty)) {
+            return false;
+          }
+          out << ">";
+        }
+        return true;
+      },
+      [&](const ast::Pointer* ptr) {
+        out << "ptr<" << ptr->storage_class << ", ";
+        if (!EmitType(out, ptr->type)) {
+          return false;
+        }
+        if (ptr->access != ast::Access::kUndefined) {
+          out << ", ";
+          if (!EmitAccess(out, ptr->access)) {
+            return false;
+          }
+        }
+        out << ">";
+        return true;
+      },
+      [&](const ast::Atomic* atomic) {
+        out << "atomic<";
+        if (!EmitType(out, atomic->type)) {
+          return false;
+        }
+        out << ">";
+        return true;
+      },
+      [&](const ast::Sampler* sampler) {
+        out << "sampler";
-    if (sampler->IsComparison()) {
-      out << "_comparison";
-    }
-  } else if (ty->Is<ast::ExternalTexture>()) {
-    out << "texture_external";
-  } else if (auto* texture = ty->As<ast::Texture>()) {
-    out << "texture_";
-    if (texture->Is<ast::DepthTexture>()) {
-      out << "depth_";
-    } else if (texture->Is<ast::DepthMultisampledTexture>()) {
-      out << "depth_multisampled_";
-    } else if (texture->Is<ast::SampledTexture>()) {
-      /* nothing to emit */
-    } else if (texture->Is<ast::MultisampledTexture>()) {
-      out << "multisampled_";
-    } else if (texture->Is<ast::StorageTexture>()) {
-      out << "storage_";
-    } else {
-      diagnostics_.add_error(diag::System::Writer, "unknown texture type");
-      return false;
-    }
+        if (sampler->IsComparison()) {
+          out << "_comparison";
+        }
+        return true;
+      },
+      [&](const ast::ExternalTexture*) {
+        out << "texture_external";
+        return true;
+      },
+      [&](const ast::Texture* texture) {
+        out << "texture_";
+        bool ok = Switch(
+            texture,
+            [&](const ast::DepthTexture*) {  //
+              out << "depth_";
+              return true;
+            },
+            [&](const ast::DepthMultisampledTexture*) {  //
+              out << "depth_multisampled_";
+              return true;
+            },
+            [&](const ast::SampledTexture*) {  //
+              /* nothing to emit */
+              return true;
+            },
+            [&](const ast::MultisampledTexture*) {  //
+              out << "multisampled_";
+              return true;
+            },
+            [&](const ast::StorageTexture*) {  //
+              out << "storage_";
+              return true;
+            },
+            [&](Default) {  //
+              diagnostics_.add_error(diag::System::Writer,
+                                     "unknown texture type");
+              return false;
+            });
+        if (!ok) {
+          return false;
+        }
-    switch (texture->dim) {
-      case ast::TextureDimension::k1d:
-        out << "1d";
-        break;
-      case ast::TextureDimension::k2d:
-        out << "2d";
-        break;
-      case ast::TextureDimension::k2dArray:
-        out << "2d_array";
-        break;
-      case ast::TextureDimension::k3d:
-        out << "3d";
-        break;
-      case ast::TextureDimension::kCube:
-        out << "cube";
-        break;
-      case ast::TextureDimension::kCubeArray:
-        out << "cube_array";
-        break;
-      default:
-        diagnostics_.add_error(diag::System::Writer,
-                               "unknown texture dimension");
-        return false;
-    }
+        switch (texture->dim) {
+          case ast::TextureDimension::k1d:
+            out << "1d";
+            break;
+          case ast::TextureDimension::k2d:
+            out << "2d";
+            break;
+          case ast::TextureDimension::k2dArray:
+            out << "2d_array";
+            break;
+          case ast::TextureDimension::k3d:
+            out << "3d";
+            break;
+          case ast::TextureDimension::kCube:
+            out << "cube";
+            break;
+          case ast::TextureDimension::kCubeArray:
+            out << "cube_array";
+            break;
+          default:
+            diagnostics_.add_error(diag::System::Writer,
+                                   "unknown texture dimension");
+            return false;
+        }
-    if (auto* sampled = texture->As<ast::SampledTexture>()) {
-      out << "<";
-      if (!EmitType(out, sampled->type)) {
+        return Switch(
+            texture,
+            [&](const ast::SampledTexture* sampled) {  //
+              out << "<";
+              if (!EmitType(out, sampled->type)) {
+                return false;
+              }
+              out << ">";
+              return true;
+            },
+            [&](const ast::MultisampledTexture* ms) {  //
+              out << "<";
+              if (!EmitType(out, ms->type)) {
+                return false;
+              }
+              out << ">";
+              return true;
+            },
+            [&](const ast::StorageTexture* storage) {  //
+              out << "<";
+              if (!EmitImageFormat(out, storage->format)) {
+                return false;
+              }
+              out << ", ";
+              if (!EmitAccess(out, storage->access)) {
+                return false;
+              }
+              out << ">";
+              return true;
+            },
+            [&](Default) {  //
+              return true;
+            });
+      },
+      [&](const ast::U32*) {
+        out << "u32";
+        return true;
+      },
+      [&](const ast::Vector* vec) {
+        out << "vec" << vec->width;
+        if (auto* el_ty = vec->type) {
+          out << "<";
+          if (!EmitType(out, el_ty)) {
+            return false;
+          }
+          out << ">";
+        }
+        return true;
+      },
+      [&](const ast::Void*) {
+        out << "void";
+        return true;
+      },
+      [&](const ast::TypeName* tn) {
+        out << program_->Symbols().NameFor(tn->name);
+        return true;
+      },
+      [&](Default) {
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
         return false;
-      }
-      out << ">";
-    } else if (auto* ms = texture->As<ast::MultisampledTexture>()) {
-      out << "<";
-      if (!EmitType(out, ms->type)) {
-        return false;
-      }
-      out << ">";
-    } else if (auto* storage = texture->As<ast::StorageTexture>()) {
-      out << "<";
-      if (!EmitImageFormat(out, storage->format)) {
-        return false;
-      }
-      out << ", ";
-      if (!EmitAccess(out, storage->access)) {
-        return false;
-      }
-      out << ">";
-    }
-  } else if (ty->Is<ast::U32>()) {
-    out << "u32";
-  } else if (auto* vec = ty->As<ast::Vector>()) {
-    out << "vec" << vec->width;
-    if (auto* el_ty = vec->type) {
-      out << "<";
-      if (!EmitType(out, el_ty)) {
-        return false;
-      }
-      out << ">";
-    }
-  } else if (ty->Is<ast::Void>()) {
-    out << "void";
-  } else if (auto* tn = ty->As<ast::TypeName>()) {
-    out << program_->Symbols().NameFor(tn->name);
-  } else {
-    diagnostics_.add_error(
-        diag::System::Writer,
-        "unknown type in EmitType: " + std::string(ty->TypeInfo().name));
-    return false;
-  }
-  return true;
+      });
 bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
@@ -632,56 +693,90 @@
     first = false;
     out << "@";
-    if (auto* workgroup = attr->As<ast::WorkgroupAttribute>()) {
-      auto values = workgroup->Values();
-      out << "workgroup_size(";
-      for (int i = 0; i < 3; i++) {
-        if (values[i]) {
-          if (i > 0) {
-            out << ", ";
+    bool ok = Switch(
+        attr,
+        [&](const ast::WorkgroupAttribute* workgroup) {
+          auto values = workgroup->Values();
+          out << "workgroup_size(";
+          for (int i = 0; i < 3; i++) {
+            if (values[i]) {
+              if (i > 0) {
+                out << ", ";
+              }
+              if (!EmitExpression(out, values[i])) {
+                return false;
+              }
+            }
-          if (!EmitExpression(out, values[i])) {
-            return false;
+          out << ")";
+          return true;
+        },
+        [&](const ast::StructBlockAttribute*) {  //
+          out << "block";
+          return true;
+        },
+        [&](const ast::StageAttribute* stage) {
+          out << "stage(" << stage->stage << ")";
+          return true;
+        },
+        [&](const ast::BindingAttribute* binding) {
+          out << "binding(" << binding->value << ")";
+          return true;
+        },
+        [&](const ast::GroupAttribute* group) {
+          out << "group(" << group->value << ")";
+          return true;
+        },
+        [&](const ast::LocationAttribute* location) {
+          out << "location(" << location->value << ")";
+          return true;
+        },
+        [&](const ast::BuiltinAttribute* builtin) {
+          out << "builtin(" << builtin->builtin << ")";
+          return true;
+        },
+        [&](const ast::InterpolateAttribute* interpolate) {
+          out << "interpolate(" << interpolate->type;
+          if (interpolate->sampling != ast::InterpolationSampling::kNone) {
+            out << ", " << interpolate->sampling;
-        }
-      }
-      out << ")";
-    } else if (attr->Is<ast::StructBlockAttribute>()) {
-      out << "block";
-    } else if (auto* stage = attr->As<ast::StageAttribute>()) {
-      out << "stage(" << stage->stage << ")";
-    } else if (auto* binding = attr->As<ast::BindingAttribute>()) {
-      out << "binding(" << binding->value << ")";
-    } else if (auto* group = attr->As<ast::GroupAttribute>()) {
-      out << "group(" << group->value << ")";
-    } else if (auto* location = attr->As<ast::LocationAttribute>()) {
-      out << "location(" << location->value << ")";
-    } else if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
-      out << "builtin(" << builtin->builtin << ")";
-    } else if (auto* interpolate = attr->As<ast::InterpolateAttribute>()) {
-      out << "interpolate(" << interpolate->type;
-      if (interpolate->sampling != ast::InterpolationSampling::kNone) {
-        out << ", " << interpolate->sampling;
-      }
-      out << ")";
-    } else if (attr->Is<ast::InvariantAttribute>()) {
-      out << "invariant";
-    } else if (auto* override_attr = attr->As<ast::OverrideAttribute>()) {
-      out << "override";
-      if (override_attr->has_value) {
-        out << "(" << override_attr->value << ")";
-      }
-    } else if (auto* size = attr->As<ast::StructMemberSizeAttribute>()) {
-      out << "size(" << size->size << ")";
-    } else if (auto* align = attr->As<ast::StructMemberAlignAttribute>()) {
-      out << "align(" << align->align << ")";
-    } else if (auto* stride = attr->As<ast::StrideAttribute>()) {
-      out << "stride(" << stride->stride << ")";
-    } else if (auto* internal = attr->As<ast::InternalAttribute>()) {
-      out << "internal(" << internal->InternalName() << ")";
-    } else {
-      TINT_ICE(Writer, diagnostics_)
-          << "Unsupported attribute '" << attr->TypeInfo().name << "'";
+          out << ")";
+          return true;
+        },
+        [&](const ast::InvariantAttribute*) {
+          out << "invariant";
+          return true;
+        },
+        [&](const ast::OverrideAttribute* override_deco) {
+          out << "override";
+          if (override_deco->has_value) {
+            out << "(" << override_deco->value << ")";
+          }
+          return true;
+        },
+        [&](const ast::StructMemberSizeAttribute* size) {
+          out << "size(" << size->size << ")";
+          return true;
+        },
+        [&](const ast::StructMemberAlignAttribute* align) {
+          out << "align(" << align->align << ")";
+          return true;
+        },
+        [&](const ast::StrideAttribute* stride) {
+          out << "stride(" << stride->stride << ")";
+          return true;
+        },
+        [&](const ast::InternalAttribute* internal) {
+          out << "internal(" << internal->InternalName() << ")";
+          return true;
+        },
+        [&](Default) {
+          TINT_ICE(Writer, diagnostics_)
+              << "Unsupported attribute '" << attr->TypeInfo().name << "'";
+          return false;
+        });
+    if (!ok) {
       return false;
@@ -809,55 +904,36 @@
 bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
-  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
-    return EmitAssign(a);
-  }
-  if (auto* b = stmt->As<ast::BlockStatement>()) {
-    return EmitBlock(b);
-  }
-  if (auto* b = stmt->As<ast::BreakStatement>()) {
-    return EmitBreak(b);
-  }
-  if (auto* c = stmt->As<ast::CallStatement>()) {
-    auto out = line();
-    if (!EmitCall(out, c->expr)) {
-      return false;
-    }
-    out << ";";
-    return true;
-  }
-  if (auto* c = stmt->As<ast::ContinueStatement>()) {
-    return EmitContinue(c);
-  }
-  if (auto* d = stmt->As<ast::DiscardStatement>()) {
-    return EmitDiscard(d);
-  }
-  if (auto* f = stmt->As<ast::FallthroughStatement>()) {
-    return EmitFallthrough(f);
-  }
-  if (auto* i = stmt->As<ast::IfStatement>()) {
-    return EmitIf(i);
-  }
-  if (auto* l = stmt->As<ast::LoopStatement>()) {
-    return EmitLoop(l);
-  }
-  if (auto* l = stmt->As<ast::ForLoopStatement>()) {
-    return EmitForLoop(l);
-  }
-  if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return EmitReturn(r);
-  }
-  if (auto* s = stmt->As<ast::SwitchStatement>()) {
-    return EmitSwitch(s);
-  }
-  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-    return EmitVariable(line(), v->variable);
-  }
-  diagnostics_.add_error(
-      diag::System::Writer,
-      "unknown statement type: " + std::string(stmt->TypeInfo().name));
-  return false;
+  return Switch(
+      stmt,  //
+      [&](const ast::AssignmentStatement* a) { return EmitAssign(a); },
+      [&](const ast::BlockStatement* b) { return EmitBlock(b); },
+      [&](const ast::BreakStatement* b) { return EmitBreak(b); },
+      [&](const ast::CallStatement* c) {
+        auto out = line();
+        if (!EmitCall(out, c->expr)) {
+          return false;
+        }
+        out << ";";
+        return true;
+      },
+      [&](const ast::ContinueStatement* c) { return EmitContinue(c); },
+      [&](const ast::DiscardStatement* d) { return EmitDiscard(d); },
+      [&](const ast::FallthroughStatement* f) { return EmitFallthrough(f); },
+      [&](const ast::IfStatement* i) { return EmitIf(i); },
+      [&](const ast::LoopStatement* l) { return EmitLoop(l); },
+      [&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
+      [&](const ast::ReturnStatement* r) { return EmitReturn(r); },
+      [&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
+      [&](const ast::VariableDeclStatement* v) {
+        return EmitVariable(line(), v->variable);
+      },
+      [&](Default) {
+        diagnostics_.add_error(
+            diag::System::Writer,
+            "unknown statement type: " + std::string(stmt->TypeInfo().name));
+        return false;
+      });
 bool GeneratorImpl::EmitStatements(const ast::StatementList& stmts) {