resolver: Optimize type dispatch with Switch()

Bug: tint:1383
Change-Id: Ia02c7ddd3e46d36134f5430e4f22df04993b2158
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81104
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Ben Clayton <bclayton@chromium.org>
diff --git a/src/resolver/dependency_graph.cc b/src/resolver/dependency_graph.cc
index cbb574d..bf82f2d 100644
--- a/src/resolver/dependency_graph.cc
+++ b/src/resolver/dependency_graph.cc
@@ -139,34 +139,31 @@
   /// dependencies of each global.
   void Scan(Global* global) {
     TINT_SCOPED_ASSIGNMENT(current_global_, global);
-
-    if (auto* str = global->node->As<ast::Struct>()) {
-      Declare(str->name, str);
-      for (auto* member : str->members) {
-        TraverseType(member->type);
-      }
-      return;
-    }
-    if (auto* alias = global->node->As<ast::Alias>()) {
-      Declare(alias->name, alias);
-      TraverseType(alias->type);
-      return;
-    }
-    if (auto* func = global->node->As<ast::Function>()) {
-      Declare(func->symbol, func);
-      TraverseAttributes(func->attributes);
-      TraverseFunction(func);
-      return;
-    }
-    if (auto* var = global->node->As<ast::Variable>()) {
-      Declare(var->symbol, var);
-      TraverseType(var->type);
-      if (var->constructor) {
-        TraverseExpression(var->constructor);
-      }
-      return;
-    }
-    UnhandledNode(diagnostics_, global->node);
+    Switch(
+        global->node,
+        [&](const ast::Struct* str) {
+          Declare(str->name, str);
+          for (auto* member : str->members) {
+            TraverseType(member->type);
+          }
+        },
+        [&](const ast::Alias* alias) {
+          Declare(alias->name, alias);
+          TraverseType(alias->type);
+        },
+        [&](const ast::Function* func) {
+          Declare(func->symbol, func);
+          TraverseAttributes(func->attributes);
+          TraverseFunction(func);
+        },
+        [&](const ast::Variable* var) {
+          Declare(var->symbol, var);
+          TraverseType(var->type);
+          if (var->constructor) {
+            TraverseExpression(var->constructor);
+          }
+        },
+        [&](Default) { UnhandledNode(diagnostics_, global->node); });
   }
 
  private:
@@ -208,78 +205,72 @@
   /// Traverses the statement, performing symbol resolution and determining
   /// global dependencies.
   void TraverseStatement(const ast::Statement* stmt) {
-    if (stmt == nullptr) {
+    if (!stmt) {
       return;
     }
-    if (auto* b = stmt->As<ast::AssignmentStatement>()) {
-      TraverseExpression(b->lhs);
-      TraverseExpression(b->rhs);
-      return;
-    }
-    if (auto* b = stmt->As<ast::BlockStatement>()) {
-      scope_stack_.Push();
-      TINT_DEFER(scope_stack_.Pop());
-      TraverseStatements(b->statements);
-      return;
-    }
-    if (auto* r = stmt->As<ast::CallStatement>()) {
-      TraverseExpression(r->expr);
-      return;
-    }
-    if (auto* l = stmt->As<ast::ForLoopStatement>()) {
-      scope_stack_.Push();
-      TINT_DEFER(scope_stack_.Pop());
-      TraverseStatement(l->initializer);
-      TraverseExpression(l->condition);
-      TraverseStatement(l->continuing);
-      TraverseStatement(l->body);
-      return;
-    }
-    if (auto* l = stmt->As<ast::LoopStatement>()) {
-      scope_stack_.Push();
-      TINT_DEFER(scope_stack_.Pop());
-      TraverseStatements(l->body->statements);
-      TraverseStatement(l->continuing);
-      return;
-    }
-    if (auto* i = stmt->As<ast::IfStatement>()) {
-      TraverseExpression(i->condition);
-      TraverseStatement(i->body);
-      for (auto* e : i->else_statements) {
-        TraverseExpression(e->condition);
-        TraverseStatement(e->body);
-      }
-      return;
-    }
-    if (auto* r = stmt->As<ast::ReturnStatement>()) {
-      TraverseExpression(r->value);
-      return;
-    }
-    if (auto* s = stmt->As<ast::SwitchStatement>()) {
-      TraverseExpression(s->condition);
-      for (auto* c : s->body) {
-        for (auto* sel : c->selectors) {
-          TraverseExpression(sel);
-        }
-        TraverseStatement(c->body);
-      }
-      return;
-    }
-    if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
-      if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
-        graph_.shadows.emplace(v->variable, shadows);
-      }
-      TraverseType(v->variable->type);
-      TraverseExpression(v->variable->constructor);
-      Declare(v->variable->symbol, v->variable);
-      return;
-    }
-    if (stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
-                      ast::DiscardStatement, ast::FallthroughStatement>()) {
-      return;
-    }
-
-    UnhandledNode(diagnostics_, stmt);
+    Switch(
+        stmt,  //
+        [&](const ast::AssignmentStatement* a) {
+          TraverseExpression(a->lhs);
+          TraverseExpression(a->rhs);
+        },
+        [&](const ast::BlockStatement* b) {
+          scope_stack_.Push();
+          TINT_DEFER(scope_stack_.Pop());
+          TraverseStatements(b->statements);
+        },
+        [&](const ast::CallStatement* r) {  //
+          TraverseExpression(r->expr);
+        },
+        [&](const ast::ForLoopStatement* l) {
+          scope_stack_.Push();
+          TINT_DEFER(scope_stack_.Pop());
+          TraverseStatement(l->initializer);
+          TraverseExpression(l->condition);
+          TraverseStatement(l->continuing);
+          TraverseStatement(l->body);
+        },
+        [&](const ast::LoopStatement* l) {
+          scope_stack_.Push();
+          TINT_DEFER(scope_stack_.Pop());
+          TraverseStatements(l->body->statements);
+          TraverseStatement(l->continuing);
+        },
+        [&](const ast::IfStatement* i) {
+          TraverseExpression(i->condition);
+          TraverseStatement(i->body);
+          for (auto* e : i->else_statements) {
+            TraverseExpression(e->condition);
+            TraverseStatement(e->body);
+          }
+        },
+        [&](const ast::ReturnStatement* r) {  //
+          TraverseExpression(r->value);
+        },
+        [&](const ast::SwitchStatement* s) {
+          TraverseExpression(s->condition);
+          for (auto* c : s->body) {
+            for (auto* sel : c->selectors) {
+              TraverseExpression(sel);
+            }
+            TraverseStatement(c->body);
+          }
+        },
+        [&](const ast::VariableDeclStatement* v) {
+          if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
+            graph_.shadows.emplace(v->variable, shadows);
+          }
+          TraverseType(v->variable->type);
+          TraverseExpression(v->variable->constructor);
+          Declare(v->variable->symbol, v->variable);
+        },
+        [&](Default) {
+          if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
+                             ast::DiscardStatement,
+                             ast::FallthroughStatement>()) {
+            UnhandledNode(diagnostics_, stmt);
+          }
+        });
   }
 
   /// Adds the symbol definition to the current scope, raising an error if two
@@ -302,21 +293,23 @@
     }
     ast::TraverseExpressions(
         root, diagnostics_, [&](const ast::Expression* expr) {
-          if (auto* ident = expr->As<ast::IdentifierExpression>()) {
-            AddDependency(ident, ident->symbol, "identifier", "references");
-          }
-          if (auto* call = expr->As<ast::CallExpression>()) {
-            if (call->target.name) {
-              AddDependency(call->target.name, call->target.name->symbol,
-                            "function", "calls");
-            }
-            if (call->target.type) {
-              TraverseType(call->target.type);
-            }
-          }
-          if (auto* cast = expr->As<ast::BitcastExpression>()) {
-            TraverseType(cast->type);
-          }
+          Switch(
+              expr,
+              [&](const ast::IdentifierExpression* ident) {
+                AddDependency(ident, ident->symbol, "identifier", "references");
+              },
+              [&](const ast::CallExpression* call) {
+                if (call->target.name) {
+                  AddDependency(call->target.name, call->target.name->symbol,
+                                "function", "calls");
+                }
+                if (call->target.type) {
+                  TraverseType(call->target.type);
+                }
+              },
+              [&](const ast::BitcastExpression* cast) {
+                TraverseType(cast->type);
+              });
           return ast::TraverseAction::Descend;
         });
   }
@@ -324,50 +317,44 @@
   /// Traverses the type node, performing symbol resolution and determining
   /// global dependencies.
   void TraverseType(const ast::Type* ty) {
-    if (ty == nullptr) {
+    if (!ty) {
       return;
     }
-    if (auto* arr = ty->As<ast::Array>()) {
-      TraverseType(arr->type);
-      TraverseExpression(arr->count);
-      return;
-    }
-    if (auto* atomic = ty->As<ast::Atomic>()) {
-      TraverseType(atomic->type);
-      return;
-    }
-    if (auto* mat = ty->As<ast::Matrix>()) {
-      TraverseType(mat->type);
-      return;
-    }
-    if (auto* ptr = ty->As<ast::Pointer>()) {
-      TraverseType(ptr->type);
-      return;
-    }
-    if (auto* tn = ty->As<ast::TypeName>()) {
-      AddDependency(tn, tn->name, "type", "references");
-      return;
-    }
-    if (auto* vec = ty->As<ast::Vector>()) {
-      TraverseType(vec->type);
-      return;
-    }
-    if (auto* tex = ty->As<ast::SampledTexture>()) {
-      TraverseType(tex->type);
-      return;
-    }
-    if (auto* tex = ty->As<ast::MultisampledTexture>()) {
-      TraverseType(tex->type);
-      return;
-    }
-    if (ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
-                    ast::DepthTexture, ast::DepthMultisampledTexture,
-                    ast::StorageTexture, ast::ExternalTexture,
-                    ast::Sampler>()) {
-      return;
-    }
-
-    UnhandledNode(diagnostics_, ty);
+    Switch(
+        ty,  //
+        [&](const ast::Array* arr) {
+          TraverseType(arr->type);  //
+          TraverseExpression(arr->count);
+        },
+        [&](const ast::Atomic* atomic) {  //
+          TraverseType(atomic->type);
+        },
+        [&](const ast::Matrix* mat) {  //
+          TraverseType(mat->type);
+        },
+        [&](const ast::Pointer* ptr) {  //
+          TraverseType(ptr->type);
+        },
+        [&](const ast::TypeName* tn) {  //
+          AddDependency(tn, tn->name, "type", "references");
+        },
+        [&](const ast::Vector* vec) {  //
+          TraverseType(vec->type);
+        },
+        [&](const ast::SampledTexture* tex) {  //
+          TraverseType(tex->type);
+        },
+        [&](const ast::MultisampledTexture* tex) {  //
+          TraverseType(tex->type);
+        },
+        [&](Default) {
+          if (!ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
+                           ast::DepthTexture, ast::DepthMultisampledTexture,
+                           ast::StorageTexture, ast::ExternalTexture,
+                           ast::Sampler>()) {
+            UnhandledNode(diagnostics_, ty);
+          }
+        });
   }
 
   /// Traverses the attribute list, performing symbol resolution and
@@ -490,17 +477,15 @@
   /// @note will raise an ICE if the node is not a type, function or variable
   /// declaration
   Symbol SymbolOf(const ast::Node* node) const {
-    if (auto* td = node->As<ast::TypeDecl>()) {
-      return td->name;
-    }
-    if (auto* func = node->As<ast::Function>()) {
-      return func->symbol;
-    }
-    if (auto* var = node->As<ast::Variable>()) {
-      return var->symbol;
-    }
-    UnhandledNode(diagnostics_, node);
-    return {};
+    return Switch(
+        node,  //
+        [&](const ast::TypeDecl* td) { return td->name; },
+        [&](const ast::Function* func) { return func->symbol; },
+        [&](const ast::Variable* var) { return var->symbol; },
+        [&](Default) {
+          UnhandledNode(diagnostics_, node);
+          return Symbol{};
+        });
   }
 
   /// @param node the ast::Node of the global declaration
@@ -516,20 +501,16 @@
   /// @note will raise an ICE if the node is not a type, function or variable
   /// declaration
   std::string KindOf(const ast::Node* node) {
-    if (node->Is<ast::Struct>()) {
-      return "struct";
-    }
-    if (node->Is<ast::Alias>()) {
-      return "alias";
-    }
-    if (node->Is<ast::Function>()) {
-      return "function";
-    }
-    if (auto* var = node->As<ast::Variable>()) {
-      return var->is_const ? "let" : "var";
-    }
-    UnhandledNode(diagnostics_, node);
-    return {};
+    return Switch(
+        node,  //
+        [&](const ast::Struct*) { return "struct"; },
+        [&](const ast::Alias*) { return "alias"; },
+        [&](const ast::Function*) { return "function"; },
+        [&](const ast::Variable* var) { return var->is_const ? "let" : "var"; },
+        [&](Default) {
+          UnhandledNode(diagnostics_, node);
+          return "<error>";
+        });
   }
 
   /// Traverses `module`, collecting all the global declarations and populating