resolver: Optimize type dispatch with Switch()
Bug: tint:1383
Change-Id: Ia02c7ddd3e46d36134f5430e4f22df04993b2158
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81104
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Ben Clayton <bclayton@chromium.org>
diff --git a/src/resolver/dependency_graph.cc b/src/resolver/dependency_graph.cc
index cbb574d..bf82f2d 100644
--- a/src/resolver/dependency_graph.cc
+++ b/src/resolver/dependency_graph.cc
@@ -139,34 +139,31 @@
/// dependencies of each global.
void Scan(Global* global) {
TINT_SCOPED_ASSIGNMENT(current_global_, global);
-
- if (auto* str = global->node->As<ast::Struct>()) {
- Declare(str->name, str);
- for (auto* member : str->members) {
- TraverseType(member->type);
- }
- return;
- }
- if (auto* alias = global->node->As<ast::Alias>()) {
- Declare(alias->name, alias);
- TraverseType(alias->type);
- return;
- }
- if (auto* func = global->node->As<ast::Function>()) {
- Declare(func->symbol, func);
- TraverseAttributes(func->attributes);
- TraverseFunction(func);
- return;
- }
- if (auto* var = global->node->As<ast::Variable>()) {
- Declare(var->symbol, var);
- TraverseType(var->type);
- if (var->constructor) {
- TraverseExpression(var->constructor);
- }
- return;
- }
- UnhandledNode(diagnostics_, global->node);
+ Switch(
+ global->node,
+ [&](const ast::Struct* str) {
+ Declare(str->name, str);
+ for (auto* member : str->members) {
+ TraverseType(member->type);
+ }
+ },
+ [&](const ast::Alias* alias) {
+ Declare(alias->name, alias);
+ TraverseType(alias->type);
+ },
+ [&](const ast::Function* func) {
+ Declare(func->symbol, func);
+ TraverseAttributes(func->attributes);
+ TraverseFunction(func);
+ },
+ [&](const ast::Variable* var) {
+ Declare(var->symbol, var);
+ TraverseType(var->type);
+ if (var->constructor) {
+ TraverseExpression(var->constructor);
+ }
+ },
+ [&](Default) { UnhandledNode(diagnostics_, global->node); });
}
private:
@@ -208,78 +205,72 @@
/// Traverses the statement, performing symbol resolution and determining
/// global dependencies.
void TraverseStatement(const ast::Statement* stmt) {
- if (stmt == nullptr) {
+ if (!stmt) {
return;
}
- if (auto* b = stmt->As<ast::AssignmentStatement>()) {
- TraverseExpression(b->lhs);
- TraverseExpression(b->rhs);
- return;
- }
- if (auto* b = stmt->As<ast::BlockStatement>()) {
- scope_stack_.Push();
- TINT_DEFER(scope_stack_.Pop());
- TraverseStatements(b->statements);
- return;
- }
- if (auto* r = stmt->As<ast::CallStatement>()) {
- TraverseExpression(r->expr);
- return;
- }
- if (auto* l = stmt->As<ast::ForLoopStatement>()) {
- scope_stack_.Push();
- TINT_DEFER(scope_stack_.Pop());
- TraverseStatement(l->initializer);
- TraverseExpression(l->condition);
- TraverseStatement(l->continuing);
- TraverseStatement(l->body);
- return;
- }
- if (auto* l = stmt->As<ast::LoopStatement>()) {
- scope_stack_.Push();
- TINT_DEFER(scope_stack_.Pop());
- TraverseStatements(l->body->statements);
- TraverseStatement(l->continuing);
- return;
- }
- if (auto* i = stmt->As<ast::IfStatement>()) {
- TraverseExpression(i->condition);
- TraverseStatement(i->body);
- for (auto* e : i->else_statements) {
- TraverseExpression(e->condition);
- TraverseStatement(e->body);
- }
- return;
- }
- if (auto* r = stmt->As<ast::ReturnStatement>()) {
- TraverseExpression(r->value);
- return;
- }
- if (auto* s = stmt->As<ast::SwitchStatement>()) {
- TraverseExpression(s->condition);
- for (auto* c : s->body) {
- for (auto* sel : c->selectors) {
- TraverseExpression(sel);
- }
- TraverseStatement(c->body);
- }
- return;
- }
- if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
- if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
- graph_.shadows.emplace(v->variable, shadows);
- }
- TraverseType(v->variable->type);
- TraverseExpression(v->variable->constructor);
- Declare(v->variable->symbol, v->variable);
- return;
- }
- if (stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
- ast::DiscardStatement, ast::FallthroughStatement>()) {
- return;
- }
-
- UnhandledNode(diagnostics_, stmt);
+ Switch(
+ stmt, //
+ [&](const ast::AssignmentStatement* a) {
+ TraverseExpression(a->lhs);
+ TraverseExpression(a->rhs);
+ },
+ [&](const ast::BlockStatement* b) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseStatements(b->statements);
+ },
+ [&](const ast::CallStatement* r) { //
+ TraverseExpression(r->expr);
+ },
+ [&](const ast::ForLoopStatement* l) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseStatement(l->initializer);
+ TraverseExpression(l->condition);
+ TraverseStatement(l->continuing);
+ TraverseStatement(l->body);
+ },
+ [&](const ast::LoopStatement* l) {
+ scope_stack_.Push();
+ TINT_DEFER(scope_stack_.Pop());
+ TraverseStatements(l->body->statements);
+ TraverseStatement(l->continuing);
+ },
+ [&](const ast::IfStatement* i) {
+ TraverseExpression(i->condition);
+ TraverseStatement(i->body);
+ for (auto* e : i->else_statements) {
+ TraverseExpression(e->condition);
+ TraverseStatement(e->body);
+ }
+ },
+ [&](const ast::ReturnStatement* r) { //
+ TraverseExpression(r->value);
+ },
+ [&](const ast::SwitchStatement* s) {
+ TraverseExpression(s->condition);
+ for (auto* c : s->body) {
+ for (auto* sel : c->selectors) {
+ TraverseExpression(sel);
+ }
+ TraverseStatement(c->body);
+ }
+ },
+ [&](const ast::VariableDeclStatement* v) {
+ if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
+ graph_.shadows.emplace(v->variable, shadows);
+ }
+ TraverseType(v->variable->type);
+ TraverseExpression(v->variable->constructor);
+ Declare(v->variable->symbol, v->variable);
+ },
+ [&](Default) {
+ if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
+ ast::DiscardStatement,
+ ast::FallthroughStatement>()) {
+ UnhandledNode(diagnostics_, stmt);
+ }
+ });
}
/// Adds the symbol definition to the current scope, raising an error if two
@@ -302,21 +293,23 @@
}
ast::TraverseExpressions(
root, diagnostics_, [&](const ast::Expression* expr) {
- if (auto* ident = expr->As<ast::IdentifierExpression>()) {
- AddDependency(ident, ident->symbol, "identifier", "references");
- }
- if (auto* call = expr->As<ast::CallExpression>()) {
- if (call->target.name) {
- AddDependency(call->target.name, call->target.name->symbol,
- "function", "calls");
- }
- if (call->target.type) {
- TraverseType(call->target.type);
- }
- }
- if (auto* cast = expr->As<ast::BitcastExpression>()) {
- TraverseType(cast->type);
- }
+ Switch(
+ expr,
+ [&](const ast::IdentifierExpression* ident) {
+ AddDependency(ident, ident->symbol, "identifier", "references");
+ },
+ [&](const ast::CallExpression* call) {
+ if (call->target.name) {
+ AddDependency(call->target.name, call->target.name->symbol,
+ "function", "calls");
+ }
+ if (call->target.type) {
+ TraverseType(call->target.type);
+ }
+ },
+ [&](const ast::BitcastExpression* cast) {
+ TraverseType(cast->type);
+ });
return ast::TraverseAction::Descend;
});
}
@@ -324,50 +317,44 @@
/// Traverses the type node, performing symbol resolution and determining
/// global dependencies.
void TraverseType(const ast::Type* ty) {
- if (ty == nullptr) {
+ if (!ty) {
return;
}
- if (auto* arr = ty->As<ast::Array>()) {
- TraverseType(arr->type);
- TraverseExpression(arr->count);
- return;
- }
- if (auto* atomic = ty->As<ast::Atomic>()) {
- TraverseType(atomic->type);
- return;
- }
- if (auto* mat = ty->As<ast::Matrix>()) {
- TraverseType(mat->type);
- return;
- }
- if (auto* ptr = ty->As<ast::Pointer>()) {
- TraverseType(ptr->type);
- return;
- }
- if (auto* tn = ty->As<ast::TypeName>()) {
- AddDependency(tn, tn->name, "type", "references");
- return;
- }
- if (auto* vec = ty->As<ast::Vector>()) {
- TraverseType(vec->type);
- return;
- }
- if (auto* tex = ty->As<ast::SampledTexture>()) {
- TraverseType(tex->type);
- return;
- }
- if (auto* tex = ty->As<ast::MultisampledTexture>()) {
- TraverseType(tex->type);
- return;
- }
- if (ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
- ast::DepthTexture, ast::DepthMultisampledTexture,
- ast::StorageTexture, ast::ExternalTexture,
- ast::Sampler>()) {
- return;
- }
-
- UnhandledNode(diagnostics_, ty);
+ Switch(
+ ty, //
+ [&](const ast::Array* arr) {
+ TraverseType(arr->type); //
+ TraverseExpression(arr->count);
+ },
+ [&](const ast::Atomic* atomic) { //
+ TraverseType(atomic->type);
+ },
+ [&](const ast::Matrix* mat) { //
+ TraverseType(mat->type);
+ },
+ [&](const ast::Pointer* ptr) { //
+ TraverseType(ptr->type);
+ },
+ [&](const ast::TypeName* tn) { //
+ AddDependency(tn, tn->name, "type", "references");
+ },
+ [&](const ast::Vector* vec) { //
+ TraverseType(vec->type);
+ },
+ [&](const ast::SampledTexture* tex) { //
+ TraverseType(tex->type);
+ },
+ [&](const ast::MultisampledTexture* tex) { //
+ TraverseType(tex->type);
+ },
+ [&](Default) {
+ if (!ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
+ ast::DepthTexture, ast::DepthMultisampledTexture,
+ ast::StorageTexture, ast::ExternalTexture,
+ ast::Sampler>()) {
+ UnhandledNode(diagnostics_, ty);
+ }
+ });
}
/// Traverses the attribute list, performing symbol resolution and
@@ -490,17 +477,15 @@
/// @note will raise an ICE if the node is not a type, function or variable
/// declaration
Symbol SymbolOf(const ast::Node* node) const {
- if (auto* td = node->As<ast::TypeDecl>()) {
- return td->name;
- }
- if (auto* func = node->As<ast::Function>()) {
- return func->symbol;
- }
- if (auto* var = node->As<ast::Variable>()) {
- return var->symbol;
- }
- UnhandledNode(diagnostics_, node);
- return {};
+ return Switch(
+ node, //
+ [&](const ast::TypeDecl* td) { return td->name; },
+ [&](const ast::Function* func) { return func->symbol; },
+ [&](const ast::Variable* var) { return var->symbol; },
+ [&](Default) {
+ UnhandledNode(diagnostics_, node);
+ return Symbol{};
+ });
}
/// @param node the ast::Node of the global declaration
@@ -516,20 +501,16 @@
/// @note will raise an ICE if the node is not a type, function or variable
/// declaration
std::string KindOf(const ast::Node* node) {
- if (node->Is<ast::Struct>()) {
- return "struct";
- }
- if (node->Is<ast::Alias>()) {
- return "alias";
- }
- if (node->Is<ast::Function>()) {
- return "function";
- }
- if (auto* var = node->As<ast::Variable>()) {
- return var->is_const ? "let" : "var";
- }
- UnhandledNode(diagnostics_, node);
- return {};
+ return Switch(
+ node, //
+ [&](const ast::Struct*) { return "struct"; },
+ [&](const ast::Alias*) { return "alias"; },
+ [&](const ast::Function*) { return "function"; },
+ [&](const ast::Variable* var) { return var->is_const ? "let" : "var"; },
+ [&](Default) {
+ UnhandledNode(diagnostics_, node);
+ return "<error>";
+ });
}
/// Traverses `module`, collecting all the global declarations and populating