Use the new Switch() inferred types

Change-Id: I48ecd18957101631caa27480e7b1937a10791118
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81106
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc
index fe32d35..9c16094 100644
--- a/src/tint/reader/spirv/function.cc
+++ b/src/tint/reader/spirv/function.cc
@@ -1125,7 +1125,7 @@
   // Recursively flatten matrices, arrays, and structures.
   return Switch(
       tip_type,
-      [&](const Matrix* matrix_type) -> bool {
+      [&](const Matrix* matrix_type) {
         index_prefix.push_back(0);
         const auto num_columns = static_cast<int>(matrix_type->columns);
         const Type* vec_ty = ty_.Vector(matrix_type->type, matrix_type->rows);
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 42f8cb1..d7c13a6 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -121,20 +121,14 @@
   for (auto* decl : dependencies_.ordered_globals) {
     Mark(decl);
     if (!Switch(
-            decl,                           //
-            [&](const ast::TypeDecl* td) {  //
-              return TypeDecl(td) != nullptr;
-            },
-            [&](const ast::Function* func) {
-              return Function(func) != nullptr;
-            },
-            [&](const ast::Variable* var) {
-              return GlobalVariable(var) != nullptr;
-            },
+            decl,  //
+            [&](const ast::TypeDecl* td) { return TypeDecl(td); },
+            [&](const ast::Function* func) { return Function(func); },
+            [&](const ast::Variable* var) { return GlobalVariable(var); },
             [&](Default) {
               TINT_UNREACHABLE(Resolver, diagnostics_)
                   << "unhandled global declaration: " << decl->TypeInfo().name;
-              return false;
+              return nullptr;
             })) {
       return false;
     }
@@ -165,23 +159,13 @@
 sem::Type* Resolver::Type(const ast::Type* ty) {
   Mark(ty);
   auto* s = Switch(
-      ty,
-      [&](const ast::Void*) -> sem::Type* {
-        return builder_->create<sem::Void>();
-      },
-      [&](const ast::Bool*) -> sem::Type* {
-        return builder_->create<sem::Bool>();
-      },
-      [&](const ast::I32*) -> sem::Type* {
-        return builder_->create<sem::I32>();
-      },
-      [&](const ast::U32*) -> sem::Type* {
-        return builder_->create<sem::U32>();
-      },
-      [&](const ast::F32*) -> sem::Type* {
-        return builder_->create<sem::F32>();
-      },
-      [&](const ast::Vector* t) -> sem::Type* {
+      ty,  //
+      [&](const ast::Void*) { return builder_->create<sem::Void>(); },
+      [&](const ast::Bool*) { return builder_->create<sem::Bool>(); },
+      [&](const ast::I32*) { return builder_->create<sem::I32>(); },
+      [&](const ast::U32*) { return builder_->create<sem::U32>(); },
+      [&](const ast::F32*) { return builder_->create<sem::F32>(); },
+      [&](const ast::Vector* t) -> sem::Vector* {
         if (!t->type) {
           AddError("missing vector element type", t->source.End());
           return nullptr;
@@ -195,7 +179,7 @@
         }
         return nullptr;
       },
-      [&](const ast::Matrix* t) -> sem::Type* {
+      [&](const ast::Matrix* t) -> sem::Matrix* {
         if (!t->type) {
           AddError("missing matrix element type", t->source.End());
           return nullptr;
@@ -212,8 +196,8 @@
         }
         return nullptr;
       },
-      [&](const ast::Array* t) -> sem::Type* { return Array(t); },
-      [&](const ast::Atomic* t) -> sem::Type* {
+      [&](const ast::Array* t) { return Array(t); },
+      [&](const ast::Atomic* t) -> sem::Atomic* {
         if (auto* el = Type(t->type)) {
           auto* a = builder_->create<sem::Atomic>(el);
           if (!ValidateAtomic(t, a)) {
@@ -223,7 +207,7 @@
         }
         return nullptr;
       },
-      [&](const ast::Pointer* t) -> sem::Type* {
+      [&](const ast::Pointer* t) -> sem::Pointer* {
         if (auto* el = Type(t->type)) {
           auto access = t->access;
           if (access == ast::kUndefined) {
@@ -233,28 +217,28 @@
         }
         return nullptr;
       },
-      [&](const ast::Sampler* t) -> sem::Type* {
+      [&](const ast::Sampler* t) {
         return builder_->create<sem::Sampler>(t->kind);
       },
-      [&](const ast::SampledTexture* t) -> sem::Type* {
+      [&](const ast::SampledTexture* t) -> sem::SampledTexture* {
         if (auto* el = Type(t->type)) {
           return builder_->create<sem::SampledTexture>(t->dim, el);
         }
         return nullptr;
       },
-      [&](const ast::MultisampledTexture* t) -> sem::Type* {
+      [&](const ast::MultisampledTexture* t) -> sem::MultisampledTexture* {
         if (auto* el = Type(t->type)) {
           return builder_->create<sem::MultisampledTexture>(t->dim, el);
         }
         return nullptr;
       },
-      [&](const ast::DepthTexture* t) -> sem::Type* {
+      [&](const ast::DepthTexture* t) {
         return builder_->create<sem::DepthTexture>(t->dim);
       },
-      [&](const ast::DepthMultisampledTexture* t) -> sem::Type* {
+      [&](const ast::DepthMultisampledTexture* t) {
         return builder_->create<sem::DepthMultisampledTexture>(t->dim);
       },
-      [&](const ast::StorageTexture* t) -> sem::Type* {
+      [&](const ast::StorageTexture* t) -> sem::StorageTexture* {
         if (auto* el = Type(t->type)) {
           if (!ValidateStorageTexture(t)) {
             return nullptr;
@@ -264,10 +248,10 @@
         }
         return nullptr;
       },
-      [&](const ast::ExternalTexture*) -> sem::Type* {
+      [&](const ast::ExternalTexture*) {
         return builder_->create<sem::ExternalTexture>();
       },
-      [&](Default) -> sem::Type* {
+      [&](Default) {
         auto* resolved = ResolvedSymbol(ty);
         return Switch(
             resolved,  //
@@ -858,62 +842,40 @@
       stmt,
       // Compound statements. These create their own sem::CompoundStatement
       // bindings.
-      [&](const ast::BlockStatement* b) -> sem::Statement* {
-        return BlockStatement(b);
-      },
-      [&](const ast::ForLoopStatement* l) -> sem::Statement* {
-        return ForLoopStatement(l);
-      },
-      [&](const ast::LoopStatement* l) -> sem::Statement* {
-        return LoopStatement(l);
-      },
-      [&](const ast::IfStatement* i) -> sem::Statement* {
-        return IfStatement(i);
-      },
-      [&](const ast::SwitchStatement* s) -> sem::Statement* {
-        return SwitchStatement(s);
-      },
+      [&](const ast::BlockStatement* b) { return BlockStatement(b); },
+      [&](const ast::ForLoopStatement* l) { return ForLoopStatement(l); },
+      [&](const ast::LoopStatement* l) { return LoopStatement(l); },
+      [&](const ast::IfStatement* i) { return IfStatement(i); },
+      [&](const ast::SwitchStatement* s) { return SwitchStatement(s); },
 
       // Non-Compound statements
-      [&](const ast::AssignmentStatement* a) -> sem::Statement* {
-        return AssignmentStatement(a);
-      },
-      [&](const ast::BreakStatement* b) -> sem::Statement* {
-        return BreakStatement(b);
-      },
-      [&](const ast::CallStatement* c) -> sem::Statement* {
-        return CallStatement(c);
-      },
-      [&](const ast::ContinueStatement* c) -> sem::Statement* {
-        return ContinueStatement(c);
-      },
-      [&](const ast::DiscardStatement* d) -> sem::Statement* {
-        return DiscardStatement(d);
-      },
-      [&](const ast::FallthroughStatement* f) -> sem::Statement* {
+      [&](const ast::AssignmentStatement* a) { return AssignmentStatement(a); },
+      [&](const ast::BreakStatement* b) { return BreakStatement(b); },
+      [&](const ast::CallStatement* c) { return CallStatement(c); },
+      [&](const ast::ContinueStatement* c) { return ContinueStatement(c); },
+      [&](const ast::DiscardStatement* d) { return DiscardStatement(d); },
+      [&](const ast::FallthroughStatement* f) {
         return FallthroughStatement(f);
       },
-      [&](const ast::ReturnStatement* r) -> sem::Statement* {
-        return ReturnStatement(r);
-      },
-      [&](const ast::VariableDeclStatement* v) -> sem::Statement* {
+      [&](const ast::ReturnStatement* r) { return ReturnStatement(r); },
+      [&](const ast::VariableDeclStatement* v) {
         return VariableDeclStatement(v);
       },
 
       // Error cases
-      [&](const ast::CaseStatement*) -> sem::Statement* {
+      [&](const ast::CaseStatement*) {
         AddError("case statement can only be used inside a switch statement",
                  stmt->source);
         return nullptr;
       },
-      [&](const ast::ElseStatement*) -> sem::Statement* {
+      [&](const ast::ElseStatement*) {
         TINT_ICE(Resolver, diagnostics_)
             << "Resolver::Statement() encountered an Else statement. Else "
                "statements are embedded in If statements, so should never be "
                "encountered as top-level statements";
         return nullptr;
       },
-      [&](Default) -> sem::Statement* {
+      [&](Default) {
         AddError(
             "unknown statement type: " + std::string(stmt->TypeInfo().name),
             stmt->source);
@@ -1196,16 +1158,12 @@
   auto* obj_ty = obj_raw_ty->UnwrapRef();
   auto* ty = Switch(
       obj_ty,  //
-      [&](const sem::Array* arr) -> const sem::Type* {
-        return arr->ElemType();
-      },
-      [&](const sem::Vector* vec) -> const sem::Type* {  //
-        return vec->type();
-      },
-      [&](const sem::Matrix* mat) -> const sem::Type* {
+      [&](const sem::Array* arr) { return arr->ElemType(); },
+      [&](const sem::Vector* vec) { return vec->type(); },
+      [&](const sem::Matrix* mat) {
         return builder_->create<sem::Vector>(mat->type(), mat->rows());
       },
-      [&](Default) -> const sem::Type* {
+      [&](Default) {
         AddError("cannot index type '" + TypeNameOf(obj_ty) + "'",
                  expr->source);
         return nullptr;
@@ -2188,19 +2146,19 @@
 sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
   return Switch(
       lit,
-      [&](const ast::SintLiteralExpression*) -> sem::Type* {
+      [&](const ast::SintLiteralExpression*) {
         return builder_->create<sem::I32>();
       },
-      [&](const ast::UintLiteralExpression*) -> sem::Type* {
+      [&](const ast::UintLiteralExpression*) {
         return builder_->create<sem::U32>();
       },
-      [&](const ast::FloatLiteralExpression*) -> sem::Type* {
+      [&](const ast::FloatLiteralExpression*) {
         return builder_->create<sem::F32>();
       },
-      [&](const ast::BoolLiteralExpression*) -> sem::Type* {
+      [&](const ast::BoolLiteralExpression*) {
         return builder_->create<sem::Bool>();
       },
-      [&](Default) -> sem::Type* {
+      [&](Default) {
         TINT_UNREACHABLE(Resolver, diagnostics_)
             << "Unhandled literal type: " << lit->TypeInfo().name;
         return nullptr;
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 83f2666..0f0623e 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -575,31 +575,29 @@
 uint32_t Builder::GenerateExpression(const ast::Expression* expr) {
   return Switch(
       expr,
-      [&](const ast::IndexAccessorExpression* a) {  //
+      [&](const ast::IndexAccessorExpression* a) {
         return GenerateAccessorExpression(a);
       },
-      [&](const ast::BinaryExpression* b) {  //
+      [&](const ast::BinaryExpression* b) {
         return GenerateBinaryExpression(b);
       },
-      [&](const ast::BitcastExpression* b) {  //
+      [&](const ast::BitcastExpression* b) {
         return GenerateBitcastExpression(b);
       },
-      [&](const ast::CallExpression* c) {  //
-        return GenerateCallExpression(c);
-      },
-      [&](const ast::IdentifierExpression* i) {  //
+      [&](const ast::CallExpression* c) { return GenerateCallExpression(c); },
+      [&](const ast::IdentifierExpression* i) {
         return GenerateIdentifierExpression(i);
       },
-      [&](const ast::LiteralExpression* l) {  //
+      [&](const ast::LiteralExpression* l) {
         return GenerateLiteralIfNeeded(nullptr, l);
       },
-      [&](const ast::MemberAccessorExpression* m) {  //
+      [&](const ast::MemberAccessorExpression* m) {
         return GenerateAccessorExpression(m);
       },
-      [&](const ast::UnaryOpExpression* u) {  //
+      [&](const ast::UnaryOpExpression* u) {
         return GenerateUnaryOpExpression(u);
       },
-      [&](Default) -> uint32_t {
+      [&](Default) {
         error_ =
             "unknown expression type: " + std::string(expr->TypeInfo().name);
         return 0;
@@ -2271,7 +2269,7 @@
       [&](const sem::TypeConstructor*) {
         return GenerateTypeConstructorOrConversion(call, nullptr);
       },
-      [&](Default) -> uint32_t {
+      [&](Default) {
         TINT_ICE(Writer, builder_.Diagnostics())
             << "unhandled call target: " << target->TypeInfo().name;
         return 0;
@@ -4101,7 +4099,7 @@
       [&](const sem::StorageTexture* t) {
         return GenerateTypeIfNeeded(t->type());
       },
-      [&](Default) -> uint32_t {  //
+      [&](Default)  {
         return 0u;
       });
   if (type_id == 0u) {