resolver: Use Switch() for type-dispatch

Bug: tint:1383
Change-Id: I9efbe6b3e7c0314a76f65b5e8969f1f20bcecf93
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/79771
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 623af60..66a6c6c 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -119,24 +119,23 @@
 
   // Process all module-scope declarations in dependency order.
   for (auto* decl : dependencies_.ordered_globals) {
-    if (auto* td = decl->As<ast::TypeDecl>()) {
-      Mark(td);
-      if (!TypeDecl(td)) {
-        return false;
-      }
-    } else if (auto* func = decl->As<ast::Function>()) {
-      Mark(func);
-      if (!Function(func)) {
-        return false;
-      }
-    } else if (auto* var = decl->As<ast::Variable>()) {
-      Mark(var);
-      if (!GlobalVariable(var)) {
-        return false;
-      }
-    } else {
-      TINT_UNREACHABLE(Resolver, diagnostics_)
-          << "unhandled global declaration: " << decl->TypeInfo().name;
+    Mark(decl);
+    if (!Switch(
+            decl,                           //
+            [&](const ast::TypeDecl* td) {  //
+              return TypeDecl(td) != nullptr;
+            },
+            [&](const ast::Function* func) {
+              return Function(func) != nullptr;
+            },
+            [&](const ast::Variable* var) {
+              return GlobalVariable(var) != nullptr;
+            },
+            [&](Default) {
+              TINT_UNREACHABLE(Resolver, diagnostics_)
+                  << "unhandled global declaration: " << decl->TypeInfo().name;
+              return false;
+            })) {
       return false;
     }
   }
@@ -165,131 +164,137 @@
 
 sem::Type* Resolver::Type(const ast::Type* ty) {
   Mark(ty);
-  auto* s = [&]() -> sem::Type* {
-    if (ty->Is<ast::Void>()) {
-      return builder_->create<sem::Void>();
-    }
-    if (ty->Is<ast::Bool>()) {
-      return builder_->create<sem::Bool>();
-    }
-    if (ty->Is<ast::I32>()) {
-      return builder_->create<sem::I32>();
-    }
-    if (ty->Is<ast::U32>()) {
-      return builder_->create<sem::U32>();
-    }
-    if (ty->Is<ast::F32>()) {
-      return builder_->create<sem::F32>();
-    }
-    if (auto* t = ty->As<ast::Vector>()) {
-      if (!t->type) {
-        AddError("missing vector element type", t->source.End());
-        return nullptr;
-      }
-      if (auto* el = Type(t->type)) {
-        if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
-          if (ValidateVector(vector, t->source)) {
-            return vector;
-          }
+  auto* s = Switch(
+      ty,
+      [&](const ast::Void*) -> sem::Type* {
+        return builder_->create<sem::Void>();
+      },
+      [&](const ast::Bool*) -> sem::Type* {
+        return builder_->create<sem::Bool>();
+      },
+      [&](const ast::I32*) -> sem::Type* {
+        return builder_->create<sem::I32>();
+      },
+      [&](const ast::U32*) -> sem::Type* {
+        return builder_->create<sem::U32>();
+      },
+      [&](const ast::F32*) -> sem::Type* {
+        return builder_->create<sem::F32>();
+      },
+      [&](const ast::Vector* t) -> sem::Type* {
+        if (!t->type) {
+          AddError("missing vector element type", t->source.End());
+          return nullptr;
         }
-      }
-      return nullptr;
-    }
-    if (auto* t = ty->As<ast::Matrix>()) {
-      if (!t->type) {
-        AddError("missing matrix element type", t->source.End());
-        return nullptr;
-      }
-      if (auto* el = Type(t->type)) {
-        if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
-          if (auto* matrix =
-                  builder_->create<sem::Matrix>(column_type, t->columns)) {
-            if (ValidateMatrix(matrix, t->source)) {
-              return matrix;
+        if (auto* el = Type(t->type)) {
+          if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
+            if (ValidateVector(vector, t->source)) {
+              return vector;
             }
           }
         }
-      }
-      return nullptr;
-    }
-    if (auto* t = ty->As<ast::Array>()) {
-      return Array(t);
-    }
-    if (auto* t = ty->As<ast::Atomic>()) {
-      if (auto* el = Type(t->type)) {
-        auto* a = builder_->create<sem::Atomic>(el);
-        if (!ValidateAtomic(t, a)) {
+        return nullptr;
+      },
+      [&](const ast::Matrix* t) -> sem::Type* {
+        if (!t->type) {
+          AddError("missing matrix element type", t->source.End());
           return nullptr;
         }
-        return a;
-      }
-      return nullptr;
-    }
-    if (auto* t = ty->As<ast::Pointer>()) {
-      if (auto* el = Type(t->type)) {
-        auto access = t->access;
-        if (access == ast::kUndefined) {
-          access = DefaultAccessForStorageClass(t->storage_class);
+        if (auto* el = Type(t->type)) {
+          if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
+            if (auto* matrix =
+                    builder_->create<sem::Matrix>(column_type, t->columns)) {
+              if (ValidateMatrix(matrix, t->source)) {
+                return matrix;
+              }
+            }
+          }
         }
-        return builder_->create<sem::Pointer>(el, t->storage_class, access);
-      }
-      return nullptr;
-    }
-    if (auto* t = ty->As<ast::Sampler>()) {
-      return builder_->create<sem::Sampler>(t->kind);
-    }
-    if (auto* t = ty->As<ast::SampledTexture>()) {
-      if (auto* el = Type(t->type)) {
-        return builder_->create<sem::SampledTexture>(t->dim, el);
-      }
-      return nullptr;
-    }
-    if (auto* t = ty->As<ast::MultisampledTexture>()) {
-      if (auto* el = Type(t->type)) {
-        return builder_->create<sem::MultisampledTexture>(t->dim, el);
-      }
-      return nullptr;
-    }
-    if (auto* t = ty->As<ast::DepthTexture>()) {
-      return builder_->create<sem::DepthTexture>(t->dim);
-    }
-    if (auto* t = ty->As<ast::DepthMultisampledTexture>()) {
-      return builder_->create<sem::DepthMultisampledTexture>(t->dim);
-    }
-    if (auto* t = ty->As<ast::StorageTexture>()) {
-      if (auto* el = Type(t->type)) {
-        if (!ValidateStorageTexture(t)) {
-          return nullptr;
+        return nullptr;
+      },
+      [&](const ast::Array* t) -> sem::Type* { return Array(t); },
+      [&](const ast::Atomic* t) -> sem::Type* {
+        if (auto* el = Type(t->type)) {
+          auto* a = builder_->create<sem::Atomic>(el);
+          if (!ValidateAtomic(t, a)) {
+            return nullptr;
+          }
+          return a;
         }
-        return builder_->create<sem::StorageTexture>(t->dim, t->format,
-                                                     t->access, el);
-      }
-      return nullptr;
-    }
-    if (ty->As<ast::ExternalTexture>()) {
-      return builder_->create<sem::ExternalTexture>();
-    }
-    return Switch(
-        ResolvedSymbol(ty),  //
-        [&](sem::Type* type) { return type; },
-        [&](sem::Variable* var) {
-          auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
-          AddError("cannot use variable '" + name + "' as type", ty->source);
-          AddNote("'" + name + "' declared here", var->Declaration()->source);
-          return nullptr;
-        },
-        [&](sem::Function* func) {
-          auto name = builder_->Symbols().NameFor(func->Declaration()->symbol);
-          AddError("cannot use function '" + name + "' as type", ty->source);
-          AddNote("'" + name + "' declared here", func->Declaration()->source);
-          return nullptr;
-        },
-        [&](Default) {
-          TINT_UNREACHABLE(Resolver, diagnostics_)
-              << "Unhandled ast::Type: " << ty->TypeInfo().name;
-          return nullptr;
-        });
-  }();
+        return nullptr;
+      },
+      [&](const ast::Pointer* t) -> sem::Type* {
+        if (auto* el = Type(t->type)) {
+          auto access = t->access;
+          if (access == ast::kUndefined) {
+            access = DefaultAccessForStorageClass(t->storage_class);
+          }
+          return builder_->create<sem::Pointer>(el, t->storage_class, access);
+        }
+        return nullptr;
+      },
+      [&](const ast::Sampler* t) -> sem::Type* {
+        return builder_->create<sem::Sampler>(t->kind);
+      },
+      [&](const ast::SampledTexture* t) -> sem::Type* {
+        if (auto* el = Type(t->type)) {
+          return builder_->create<sem::SampledTexture>(t->dim, el);
+        }
+        return nullptr;
+      },
+      [&](const ast::MultisampledTexture* t) -> sem::Type* {
+        if (auto* el = Type(t->type)) {
+          return builder_->create<sem::MultisampledTexture>(t->dim, el);
+        }
+        return nullptr;
+      },
+      [&](const ast::DepthTexture* t) -> sem::Type* {
+        return builder_->create<sem::DepthTexture>(t->dim);
+      },
+      [&](const ast::DepthMultisampledTexture* t) -> sem::Type* {
+        return builder_->create<sem::DepthMultisampledTexture>(t->dim);
+      },
+      [&](const ast::StorageTexture* t) -> sem::Type* {
+        if (auto* el = Type(t->type)) {
+          if (!ValidateStorageTexture(t)) {
+            return nullptr;
+          }
+          return builder_->create<sem::StorageTexture>(t->dim, t->format,
+                                                       t->access, el);
+        }
+        return nullptr;
+      },
+      [&](const ast::ExternalTexture*) -> sem::Type* {
+        return builder_->create<sem::ExternalTexture>();
+      },
+      [&](Default) -> sem::Type* {
+        return Switch(
+            ResolvedSymbol(ty),  //
+            [&](sem::Type* type) { return type; },
+            [&](sem::Variable* var) {
+              auto name =
+                  builder_->Symbols().NameFor(var->Declaration()->symbol);
+              AddError("cannot use variable '" + name + "' as type",
+                       ty->source);
+              AddNote("'" + name + "' declared here",
+                      var->Declaration()->source);
+              return nullptr;
+            },
+            [&](sem::Function* func) {
+              auto name =
+                  builder_->Symbols().NameFor(func->Declaration()->symbol);
+              AddError("cannot use function '" + name + "' as type",
+                       ty->source);
+              AddNote("'" + name + "' declared here",
+                      func->Declaration()->source);
+              return nullptr;
+            },
+            [&](Default) {
+              TINT_UNREACHABLE(Resolver, diagnostics_)
+                  << "Unhandled ast::Type: " << ty->TypeInfo().name;
+              return nullptr;
+            });
+      });
 
   if (s) {
     builder_->Sem().Add(ty, s);
@@ -520,30 +525,27 @@
 
 void Resolver::SetShadows() {
   for (auto it : dependencies_.shadows) {
-    auto* var = Sem(it.first);
-    if (auto* local = var->As<sem::LocalVariable>()) {
-      local->SetShadows(Sem(it.second));
-    }
-    if (auto* param = var->As<sem::Parameter>()) {
-      param->SetShadows(Sem(it.second));
-    }
+    Switch(
+        Sem(it.first),  //
+        [&](sem::LocalVariable* local) { local->SetShadows(Sem(it.second)); },
+        [&](sem::Parameter* param) { param->SetShadows(Sem(it.second)); });
   }
 }  // namespace resolver
 
-bool Resolver::GlobalVariable(const ast::Variable* var) {
+sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) {
   auto* sem = Variable(var, VariableKind::kGlobal);
   if (!sem) {
-    return false;
+    return nullptr;
   }
 
   auto storage_class = sem->StorageClass();
   if (!var->is_const && storage_class == ast::StorageClass::kNone) {
     AddError("global variables must have a storage class", var->source);
-    return false;
+    return nullptr;
   }
   if (var->is_const && storage_class != ast::StorageClass::kNone) {
     AddError("global constants shouldn't have a storage class", var->source);
-    return false;
+    return nullptr;
   }
 
   for (auto* attr : var->attributes) {
@@ -558,20 +560,20 @@
   }
 
   if (!ValidateNoDuplicateAttributes(var->attributes)) {
-    return false;
+    return nullptr;
   }
 
   if (!ValidateGlobalVariable(sem)) {
-    return false;
+    return nullptr;
   }
 
   // TODO(bclayton): Call this at the end of resolve on all uniform and storage
   // referenced structs
   if (!ValidateStorageClassLayout(sem)) {
-    return false;
+    return nullptr;
   }
 
-  return true;
+  return sem->As<sem::GlobalVariable>();
 }
 
 sem::Function* Resolver::Function(const ast::Function* decl) {
@@ -858,66 +860,71 @@
 }
 
 sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
-  if (stmt->Is<ast::CaseStatement>()) {
-    AddError("case statement can only be used inside a switch statement",
-             stmt->source);
-    return nullptr;
-  }
-  if (stmt->Is<ast::ElseStatement>()) {
-    TINT_ICE(Resolver, diagnostics_)
-        << "Resolver::Statement() encountered an Else statement. Else "
-           "statements are embedded in If statements, so should never be "
-           "encountered as top-level statements";
-    return nullptr;
-  }
+  return Switch(
+      stmt,
+      // Compound statements. These create their own sem::CompoundStatement
+      // bindings.
+      [&](const ast::BlockStatement* b) -> sem::Statement* {
+        return BlockStatement(b);
+      },
+      [&](const ast::ForLoopStatement* l) -> sem::Statement* {
+        return ForLoopStatement(l);
+      },
+      [&](const ast::LoopStatement* l) -> sem::Statement* {
+        return LoopStatement(l);
+      },
+      [&](const ast::IfStatement* i) -> sem::Statement* {
+        return IfStatement(i);
+      },
+      [&](const ast::SwitchStatement* s) -> sem::Statement* {
+        return SwitchStatement(s);
+      },
 
-  // Compound statements. These create their own sem::CompoundStatement
-  // bindings.
-  if (auto* b = stmt->As<ast::BlockStatement>()) {
-    return BlockStatement(b);
-  }
-  if (auto* l = stmt->As<ast::ForLoopStatement>()) {
-    return ForLoopStatement(l);
-  }
-  if (auto* l = stmt->As<ast::LoopStatement>()) {
-    return LoopStatement(l);
-  }
-  if (auto* i = stmt->As<ast::IfStatement>()) {
-    return IfStatement(i);
-  }
-  if (auto* s = stmt->As<ast::SwitchStatement>()) {
-    return SwitchStatement(s);
-  }
+      // Non-Compound statements
+      [&](const ast::AssignmentStatement* a) -> sem::Statement* {
+        return AssignmentStatement(a);
+      },
+      [&](const ast::BreakStatement* b) -> sem::Statement* {
+        return BreakStatement(b);
+      },
+      [&](const ast::CallStatement* c) -> sem::Statement* {
+        return CallStatement(c);
+      },
+      [&](const ast::ContinueStatement* c) -> sem::Statement* {
+        return ContinueStatement(c);
+      },
+      [&](const ast::DiscardStatement* d) -> sem::Statement* {
+        return DiscardStatement(d);
+      },
+      [&](const ast::FallthroughStatement* f) -> sem::Statement* {
+        return FallthroughStatement(f);
+      },
+      [&](const ast::ReturnStatement* r) -> sem::Statement* {
+        return ReturnStatement(r);
+      },
+      [&](const ast::VariableDeclStatement* v) -> sem::Statement* {
+        return VariableDeclStatement(v);
+      },
 
-  // Non-Compound statements
-  if (auto* a = stmt->As<ast::AssignmentStatement>()) {
-    return AssignmentStatement(a);
-  }
-  if (auto* b = stmt->As<ast::BreakStatement>()) {
-    return BreakStatement(b);
-  }
-  if (auto* c = stmt->As<ast::CallStatement>()) {
-    return CallStatement(c);
-  }
-  if (auto* c = stmt->As<ast::ContinueStatement>()) {
-    return ContinueStatement(c);
-  }
-  if (auto* d = stmt->As<ast::DiscardStatement>()) {
-    return DiscardStatement(d);
-  }
-  if (auto* f = stmt->As<ast::FallthroughStatement>()) {
-    return FallthroughStatement(f);
-  }
-  if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return ReturnStatement(r);
-  }
-  if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-    return VariableDeclStatement(v);
-  }
-
-  AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
-           stmt->source);
-  return nullptr;
+      // Error cases
+      [&](const ast::CaseStatement*) -> sem::Statement* {
+        AddError("case statement can only be used inside a switch statement",
+                 stmt->source);
+        return nullptr;
+      },
+      [&](const ast::ElseStatement*) -> sem::Statement* {
+        TINT_ICE(Resolver, diagnostics_)
+            << "Resolver::Statement() encountered an Else statement. Else "
+               "statements are embedded in If statements, so should never be "
+               "encountered as top-level statements";
+        return nullptr;
+      },
+      [&](Default) -> sem::Statement* {
+        AddError(
+            "unknown statement type: " + std::string(stmt->TypeInfo().name),
+            stmt->source);
+        return nullptr;
+      });
 }
 
 sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt) {
@@ -1137,32 +1144,42 @@
   }
 
   for (auto* expr : utils::Reverse(sorted)) {
-    sem::Expression* sem_expr = nullptr;
-    if (auto* array = expr->As<ast::IndexAccessorExpression>()) {
-      sem_expr = IndexAccessor(array);
-    } else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
-      sem_expr = Binary(bin_op);
-    } else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
-      sem_expr = Bitcast(bitcast);
-    } else if (auto* call = expr->As<ast::CallExpression>()) {
-      sem_expr = Call(call);
-    } else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
-      sem_expr = Identifier(ident);
-    } else if (auto* literal = expr->As<ast::LiteralExpression>()) {
-      sem_expr = Literal(literal);
-    } else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
-      sem_expr = MemberAccessor(member);
-    } else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
-      sem_expr = UnaryOp(unary);
-    } else if (expr->Is<ast::PhonyExpression>()) {
-      sem_expr = builder_->create<sem::Expression>(
-          expr, builder_->create<sem::Void>(), current_statement_,
-          sem::Constant{}, /* has_side_effects */ false);
-    } else {
-      TINT_ICE(Resolver, diagnostics_)
-          << "unhandled expression type: " << expr->TypeInfo().name;
-      return nullptr;
-    }
+    auto* sem_expr = Switch(
+        expr,
+        [&](const ast::IndexAccessorExpression* array) -> sem::Expression* {
+          return IndexAccessor(array);
+        },
+        [&](const ast::BinaryExpression* bin_op) -> sem::Expression* {
+          return Binary(bin_op);
+        },
+        [&](const ast::BitcastExpression* bitcast) -> sem::Expression* {
+          return Bitcast(bitcast);
+        },
+        [&](const ast::CallExpression* call) -> sem::Expression* {
+          return Call(call);
+        },
+        [&](const ast::IdentifierExpression* ident) -> sem::Expression* {
+          return Identifier(ident);
+        },
+        [&](const ast::LiteralExpression* literal) -> sem::Expression* {
+          return Literal(literal);
+        },
+        [&](const ast::MemberAccessorExpression* member) -> sem::Expression* {
+          return MemberAccessor(member);
+        },
+        [&](const ast::UnaryOpExpression* unary) -> sem::Expression* {
+          return UnaryOp(unary);
+        },
+        [&](const ast::PhonyExpression*) -> sem::Expression* {
+          return builder_->create<sem::Expression>(
+              expr, builder_->create<sem::Void>(), current_statement_,
+              sem::Constant{}, /* has_side_effects */ false);
+        },
+        [&](Default) {
+          TINT_ICE(Resolver, diagnostics_)
+              << "unhandled expression type: " << expr->TypeInfo().name;
+          return nullptr;
+        });
     if (!sem_expr) {
       return nullptr;
     }
@@ -1183,15 +1200,23 @@
   auto* obj = Sem(expr->object);
   auto* obj_raw_ty = obj->Type();
   auto* obj_ty = obj_raw_ty->UnwrapRef();
-  const sem::Type* ty = nullptr;
-  if (auto* arr = obj_ty->As<sem::Array>()) {
-    ty = arr->ElemType();
-  } else if (auto* vec = obj_ty->As<sem::Vector>()) {
-    ty = vec->type();
-  } else if (auto* mat = obj_ty->As<sem::Matrix>()) {
-    ty = builder_->create<sem::Vector>(mat->type(), mat->rows());
-  } else {
-    AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", expr->source);
+  auto* ty = Switch(
+      obj_ty,  //
+      [&](const sem::Array* arr) -> const sem::Type* {
+        return arr->ElemType();
+      },
+      [&](const sem::Vector* vec) -> const sem::Type* {  //
+        return vec->type();
+      },
+      [&](const sem::Matrix* mat) -> const sem::Type* {
+        return builder_->create<sem::Vector>(mat->type(), mat->rows());
+      },
+      [&](Default) -> const sem::Type* {
+        AddError("cannot index type '" + TypeNameOf(obj_ty) + "'",
+                 expr->source);
+        return nullptr;
+      });
+  if (ty == nullptr) {
     return nullptr;
   }
 
@@ -1528,24 +1553,30 @@
         // Now that the argument types have been determined, make sure that
         // they obey the conversion rules laid out in
         // https://gpuweb.github.io/gpuweb/wgsl/#conversion-expr.
-        bool ok = true;
-        if (auto* vec_type = target->As<sem::Vector>()) {
-          ok = ValidateVectorConstructorOrCast(expr, vec_type);
-        } else if (auto* mat_type = target->As<sem::Matrix>()) {
-          // Note: Matrix types currently cannot be converted (the element
-          // type must only be f32). We implement this for the day we support
-          // other matrix element types.
-          ok = ValidateMatrixConstructorOrCast(expr, mat_type);
-        } else if (target->is_scalar()) {
-          ok = ValidateScalarConstructorOrCast(expr, target);
-        } else if (auto* arr_type = target->As<sem::Array>()) {
-          ok = ValidateArrayConstructorOrCast(expr, arr_type);
-        } else if (auto* struct_type = target->As<sem::Struct>()) {
-          ok = ValidateStructureConstructorOrCast(expr, struct_type);
-        } else {
-          AddError("type is not constructible", expr->source);
-          return nullptr;
-        }
+        bool ok = Switch(
+            target,
+            [&](const sem::Vector* vec_type) {
+              return ValidateVectorConstructorOrCast(expr, vec_type);
+            },
+            [&](const sem::Matrix* mat_type) {
+              // Note: Matrix types currently cannot be converted (the element
+              // type must only be f32). We implement this for the day we
+              // support other matrix element types.
+              return ValidateMatrixConstructorOrCast(expr, mat_type);
+            },
+            [&](const sem::Array* arr_type) {
+              return ValidateArrayConstructorOrCast(expr, arr_type);
+            },
+            [&](const sem::Struct* struct_type) {
+              return ValidateStructureConstructorOrCast(expr, struct_type);
+            },
+            [&](Default) {
+              if (target->is_scalar()) {
+                return ValidateScalarConstructorOrCast(expr, target);
+              }
+              AddError("type is not constructible", expr->source);
+              return false;
+            });
         if (!ok) {
           return nullptr;
         }
@@ -1588,21 +1619,27 @@
         // Now that the argument types have been determined, make sure that
         // they obey the constructor type rules laid out in
         // https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr.
-        bool ok = true;
-        if (auto* vec_type = ty->As<sem::Vector>()) {
-          ok = ValidateVectorConstructorOrCast(expr, vec_type);
-        } else if (auto* mat_type = ty->As<sem::Matrix>()) {
-          ok = ValidateMatrixConstructorOrCast(expr, mat_type);
-        } else if (ty->is_scalar()) {
-          ok = ValidateScalarConstructorOrCast(expr, ty);
-        } else if (auto* arr_type = ty->As<sem::Array>()) {
-          ok = ValidateArrayConstructorOrCast(expr, arr_type);
-        } else if (auto* struct_type = ty->As<sem::Struct>()) {
-          ok = ValidateStructureConstructorOrCast(expr, struct_type);
-        } else {
-          AddError("type is not constructible", expr->source);
-          return nullptr;
-        }
+        bool ok = Switch(
+            ty,
+            [&](const sem::Vector* vec_type) {
+              return ValidateVectorConstructorOrCast(expr, vec_type);
+            },
+            [&](const sem::Matrix* mat_type) {
+              return ValidateMatrixConstructorOrCast(expr, mat_type);
+            },
+            [&](const sem::Array* arr_type) {
+              return ValidateArrayConstructorOrCast(expr, arr_type);
+            },
+            [&](const sem::Struct* struct_type) {
+              return ValidateStructureConstructorOrCast(expr, struct_type);
+            },
+            [&](Default) {
+              if (ty->is_scalar()) {
+                return ValidateScalarConstructorOrCast(expr, ty);
+              }
+              AddError("type is not constructible", expr->source);
+              return false;
+            });
         if (!ok) {
           return nullptr;
         }
@@ -2155,21 +2192,25 @@
 }
 
 sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
-  if (lit->Is<ast::SintLiteralExpression>()) {
-    return builder_->create<sem::I32>();
-  }
-  if (lit->Is<ast::UintLiteralExpression>()) {
-    return builder_->create<sem::U32>();
-  }
-  if (lit->Is<ast::FloatLiteralExpression>()) {
-    return builder_->create<sem::F32>();
-  }
-  if (lit->Is<ast::BoolLiteralExpression>()) {
-    return builder_->create<sem::Bool>();
-  }
-  TINT_UNREACHABLE(Resolver, diagnostics_)
-      << "Unhandled literal type: " << lit->TypeInfo().name;
-  return nullptr;
+  return Switch(
+      lit,
+      [&](const ast::SintLiteralExpression*) -> sem::Type* {
+        return builder_->create<sem::I32>();
+      },
+      [&](const ast::UintLiteralExpression*) -> sem::Type* {
+        return builder_->create<sem::U32>();
+      },
+      [&](const ast::FloatLiteralExpression*) -> sem::Type* {
+        return builder_->create<sem::F32>();
+      },
+      [&](const ast::BoolLiteralExpression*) -> sem::Type* {
+        return builder_->create<sem::Bool>();
+      },
+      [&](Default) -> sem::Type* {
+        TINT_UNREACHABLE(Resolver, diagnostics_)
+            << "Unhandled literal type: " << lit->TypeInfo().name;
+        return nullptr;
+      });
 }
 
 sem::Array* Resolver::Array(const ast::Array* arr) {
@@ -2770,30 +2811,23 @@
 
 // https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types
 bool Resolver::IsFixedFootprint(const sem::Type* type) const {
-  if (type->is_scalar()) {
-    return true;
-  }
-  if (type->Is<sem::Vector>()) {
-    return true;
-  }
-  if (type->Is<sem::Matrix>()) {
-    return true;
-  }
-  if (type->Is<sem::Atomic>()) {
-    return true;
-  }
-  if (auto* arr = type->As<sem::Array>()) {
-    return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
-  }
-  if (auto* str = type->As<sem::Struct>()) {
-    for (auto* member : str->Members()) {
-      if (!IsFixedFootprint(member->Type())) {
-        return false;
-      }
-    }
-    return true;
-  }
-  return false;
+  return Switch(
+      type,                                      //
+      [&](const sem::Vector*) { return true; },  //
+      [&](const sem::Matrix*) { return true; },  //
+      [&](const sem::Atomic*) { return true; },
+      [&](const sem::Array* arr) {
+        return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType());
+      },
+      [&](const sem::Struct* str) {
+        for (auto* member : str->Members()) {
+          if (!IsFixedFootprint(member->Type())) {
+            return false;
+          }
+        }
+        return true;
+      },
+      [&](Default) { return type->is_scalar(); });
 }
 
 // https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
@@ -2806,27 +2840,22 @@
   if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
     return true;
   }
-  if (auto* vec = type->As<sem::Vector>()) {
-    return IsHostShareable(vec->type());
-  }
-  if (auto* mat = type->As<sem::Matrix>()) {
-    return IsHostShareable(mat->type());
-  }
-  if (auto* arr = type->As<sem::Array>()) {
-    return IsHostShareable(arr->ElemType());
-  }
-  if (auto* str = type->As<sem::Struct>()) {
-    for (auto* member : str->Members()) {
-      if (!IsHostShareable(member->Type())) {
-        return false;
-      }
-    }
-    return true;
-  }
-  if (auto* atomic = type->As<sem::Atomic>()) {
-    return IsHostShareable(atomic->Type());
-  }
-  return false;
+  return Switch(
+      type,  //
+      [&](const sem::Vector* vec) { return IsHostShareable(vec->type()); },
+      [&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); },
+      [&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); },
+      [&](const sem::Struct* str) {
+        for (auto* member : str->Members()) {
+          if (!IsHostShareable(member->Type())) {
+            return false;
+          }
+        }
+        return true;
+      },
+      [&](const sem::Atomic* atomic) {
+        return IsHostShareable(atomic->Type());
+      });
 }
 
 bool Resolver::IsBuiltin(Symbol symbol) const {
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 33d945d..ed7ab3a 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -219,6 +219,7 @@
   sem::ElseStatement* ElseStatement(const ast::ElseStatement*);
   sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
   sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
+  sem::GlobalVariable* GlobalVariable(const ast::Variable*);
   sem::Statement* Parameter(const ast::Variable*);
   sem::IfStatement* IfStatement(const ast::IfStatement*);
   sem::LoopStatement* LoopStatement(const ast::LoopStatement*);
@@ -228,8 +229,6 @@
   sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
   bool Statements(const ast::StatementList&);
 
-  bool GlobalVariable(const ast::Variable*);
-
   // AST and Type validation methods
   // Each return true on success, false on failure.
   bool ValidateAlias(const ast::Alias*);