tint/resolver: Bring back enum suggestions
The dependency graph no longer errors if a symbol cannot be resolved; instead, ResolvedIdentifier now has an unresolved variant.
This is required because only the second resolve phase has the full context of how the identifier is used, which is needed to provide the hints.
Also: Split Slice out of utils/vector.h so it can be used as a lightweight view over static data.
Fixed: tint:1842
Change-Id: I31fa7697790be24c35b7e4fab5ca903c8a7afbba
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/121020
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc
index e0f0656..0b8f728 100644
--- a/src/tint/resolver/dependency_graph.cc
+++ b/src/tint/resolver/dependency_graph.cc
@@ -79,8 +79,6 @@
struct DependencyInfo {
/// The source of the symbol that forms the dependency
Source source;
- /// A string describing how the dependency is referenced. e.g. 'calls'
- const char* action = nullptr;
};
/// DependencyEdge describes the two Globals used to define a dependency
@@ -174,12 +172,12 @@
Declare(str->name->symbol, str);
for (auto* member : str->members) {
TraverseAttributes(member->attributes);
- TraverseTypeExpression(member->type);
+ TraverseExpression(member->type);
}
},
[&](const ast::Alias* alias) {
Declare(alias->name->symbol, alias);
- TraverseTypeExpression(alias->type);
+ TraverseExpression(alias->type);
},
[&](const ast::Function* func) {
Declare(func->name->symbol, func);
@@ -195,9 +193,7 @@
[&](const ast::Enable*) {
// Enable directives do not affect the dependency graph.
},
- [&](const ast::ConstAssert* assertion) {
- TraverseValueExpression(assertion->condition);
- },
+ [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) { UnhandledNode(diagnostics_, global->node); });
}
@@ -205,12 +201,12 @@
/// Traverses the variable, performing symbol resolution.
void TraverseVariable(const ast::Variable* v) {
if (auto* var = v->As<ast::Var>()) {
- TraverseAddressSpaceExpression(var->declared_address_space);
- TraverseAccessExpression(var->declared_access);
+ TraverseExpression(var->declared_address_space);
+ TraverseExpression(var->declared_access);
}
- TraverseTypeExpression(v->type);
+ TraverseExpression(v->type);
TraverseAttributes(v->attributes);
- TraverseValueExpression(v->initializer);
+ TraverseExpression(v->initializer);
}
/// Traverses the function, performing symbol resolution and determining global dependencies.
@@ -222,10 +218,10 @@
// with the same identifier as its type.
for (auto* param : func->params) {
TraverseAttributes(param->attributes);
- TraverseTypeExpression(param->type);
+ TraverseExpression(param->type);
}
// Resolve the return type
- TraverseTypeExpression(func->return_type);
+ TraverseExpression(func->return_type);
// Push the scope stack for the parameters and function body.
scope_stack_.Push();
@@ -259,29 +255,29 @@
Switch(
stmt, //
[&](const ast::AssignmentStatement* a) {
- TraverseValueExpression(a->lhs);
- TraverseValueExpression(a->rhs);
+ TraverseExpression(a->lhs);
+ TraverseExpression(a->rhs);
},
[&](const ast::BlockStatement* b) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatements(b->statements);
},
- [&](const ast::BreakIfStatement* b) { TraverseValueExpression(b->condition); },
- [&](const ast::CallStatement* r) { TraverseValueExpression(r->expr); },
+ [&](const ast::BreakIfStatement* b) { TraverseExpression(b->condition); },
+ [&](const ast::CallStatement* r) { TraverseExpression(r->expr); },
[&](const ast::CompoundAssignmentStatement* a) {
- TraverseValueExpression(a->lhs);
- TraverseValueExpression(a->rhs);
+ TraverseExpression(a->lhs);
+ TraverseExpression(a->rhs);
},
[&](const ast::ForLoopStatement* l) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatement(l->initializer);
- TraverseValueExpression(l->condition);
+ TraverseExpression(l->condition);
TraverseStatement(l->continuing);
TraverseStatement(l->body);
},
- [&](const ast::IncrementDecrementStatement* i) { TraverseValueExpression(i->lhs); },
+ [&](const ast::IncrementDecrementStatement* i) { TraverseExpression(i->lhs); },
[&](const ast::LoopStatement* l) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
@@ -289,18 +285,18 @@
TraverseStatement(l->continuing);
},
[&](const ast::IfStatement* i) {
- TraverseValueExpression(i->condition);
+ TraverseExpression(i->condition);
TraverseStatement(i->body);
if (i->else_statement) {
TraverseStatement(i->else_statement);
}
},
- [&](const ast::ReturnStatement* r) { TraverseValueExpression(r->value); },
+ [&](const ast::ReturnStatement* r) { TraverseExpression(r->value); },
[&](const ast::SwitchStatement* s) {
- TraverseValueExpression(s->condition);
+ TraverseExpression(s->condition);
for (auto* c : s->body) {
for (auto* sel : c->selectors) {
- TraverseValueExpression(sel->expr);
+ TraverseExpression(sel->expr);
}
TraverseStatement(c->body);
}
@@ -315,12 +311,10 @@
[&](const ast::WhileStatement* w) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
- TraverseValueExpression(w->condition);
+ TraverseExpression(w->condition);
TraverseStatement(w->body);
},
- [&](const ast::ConstAssert* assertion) {
- TraverseValueExpression(assertion->condition);
- },
+ [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) {
if (TINT_UNLIKELY((!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement>()))) {
@@ -340,70 +334,28 @@
}
}
- /// Traverses the expression @p root_expr for the intended use as a value, performing symbol
- /// resolution and determining global dependencies.
- void TraverseValueExpression(const ast::Expression* root) {
- TraverseExpression(root, "identifier", "references");
- }
-
- /// Traverses the expression @p root_expr for the intended use as a type, performing symbol
- /// resolution and determining global dependencies.
- void TraverseTypeExpression(const ast::Expression* root) {
- TraverseExpression(root, "type", "references");
- }
-
- /// Traverses the expression @p root_expr for the intended use as an address space, performing
- /// symbol resolution and determining global dependencies.
- void TraverseAddressSpaceExpression(const ast::Expression* root) {
- TraverseExpression(root, "address space", "references");
- }
-
- /// Traverses the expression @p root_expr for the intended use as an access, performing symbol
- /// resolution and determining global dependencies.
- void TraverseAccessExpression(const ast::Expression* root) {
- TraverseExpression(root, "access", "references");
- }
-
- /// Traverses the expression @p root_expr for the intended use as a call target, performing
- /// symbol resolution and determining global dependencies.
- void TraverseCallableExpression(const ast::Expression* root) {
- TraverseExpression(root, "function", "calls");
- }
-
/// Traverses the expression @p root_expr, performing symbol resolution and determining global
/// dependencies.
- void TraverseExpression(const ast::Expression* root_expr,
- const char* root_use,
- const char* root_action) {
+ void TraverseExpression(const ast::Expression* root_expr) {
if (!root_expr) {
return;
}
- struct Pending {
- const ast::Expression* expr;
- const char* use;
- const char* action;
- };
- utils::Vector<Pending, 8> pending{{root_expr, root_use, root_action}};
+ utils::Vector<const ast::Expression*, 8> pending{root_expr};
while (!pending.IsEmpty()) {
- auto next = pending.Pop();
- ast::TraverseExpressions(next.expr, diagnostics_, [&](const ast::Expression* expr) {
+ ast::TraverseExpressions(pending.Pop(), diagnostics_, [&](const ast::Expression* expr) {
Switch(
expr,
[&](const ast::IdentifierExpression* e) {
- AddDependency(e->identifier, e->identifier->symbol, next.use, next.action);
+ AddDependency(e->identifier, e->identifier->symbol);
if (auto* tmpl_ident = e->identifier->As<ast::TemplatedIdentifier>()) {
for (auto* arg : tmpl_ident->arguments) {
- pending.Push({arg, "identifier", "references"});
+ pending.Push(arg);
}
}
},
- [&](const ast::CallExpression* call) {
- TraverseCallableExpression(call->target);
- },
- [&](const ast::BitcastExpression* cast) {
- TraverseTypeExpression(cast->type);
- });
+ [&](const ast::CallExpression* call) { TraverseExpression(call->target); },
+ [&](const ast::BitcastExpression* cast) { TraverseExpression(cast->type); });
return ast::TraverseAction::Descend;
});
}
@@ -423,42 +375,42 @@
bool handled = Switch(
attr,
[&](const ast::BindingAttribute* binding) {
- TraverseValueExpression(binding->expr);
+ TraverseExpression(binding->expr);
return true;
},
[&](const ast::BuiltinAttribute* builtin) {
- TraverseExpression(builtin->builtin, "builtin", "references");
+ TraverseExpression(builtin->builtin);
return true;
},
[&](const ast::GroupAttribute* group) {
- TraverseValueExpression(group->expr);
+ TraverseExpression(group->expr);
return true;
},
[&](const ast::IdAttribute* id) {
- TraverseValueExpression(id->expr);
+ TraverseExpression(id->expr);
return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
- TraverseExpression(interpolate->type, "interpolation type", "references");
- TraverseExpression(interpolate->sampling, "interpolation sampling", "references");
+ TraverseExpression(interpolate->type);
+ TraverseExpression(interpolate->sampling);
return true;
},
[&](const ast::LocationAttribute* loc) {
- TraverseValueExpression(loc->expr);
+ TraverseExpression(loc->expr);
return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
- TraverseValueExpression(align->expr);
+ TraverseExpression(align->expr);
return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
- TraverseValueExpression(size->expr);
+ TraverseExpression(size->expr);
return true;
},
[&](const ast::WorkgroupAttribute* wg) {
- TraverseValueExpression(wg->x);
- TraverseValueExpression(wg->y);
- TraverseValueExpression(wg->z);
+ TraverseExpression(wg->x);
+ TraverseExpression(wg->y);
+ TraverseExpression(wg->z);
return true;
});
if (handled) {
@@ -476,10 +428,7 @@
}
- /// Adds the dependency from @p from to @p to, erroring if @p to cannot be resolved.
+ /// Adds a dependency from @p from to @p to, or marks @p from unresolved if @p to is unknown.
- void AddDependency(const ast::Identifier* from,
- Symbol to,
- const char* use,
- const char* action) {
+ void AddDependency(const ast::Identifier* from, Symbol to) {
auto* resolved = scope_stack_.Get(to);
if (!resolved) {
auto s = symbols_.NameFor(to);
@@ -521,13 +470,14 @@
return;
}
- UnknownSymbol(to, from->source, use);
+ // Unresolved.
+ graph_.resolved_identifiers.Add(from, UnresolvedIdentifier{s});
return;
}
if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
- DependencyInfo{from->source, action})) {
+ DependencyInfo{from->source})) {
current_global_->deps.Push(*global);
}
}
@@ -535,12 +485,6 @@
graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved));
}
- /// Appends an error to the diagnostics that the given symbol cannot be resolved.
- void UnknownSymbol(Symbol name, Source source, const char* use) {
- AddError(diagnostics_, "unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'",
- source);
- }
-
using VariableMap = utils::Hashmap<Symbol, const ast::Variable*, 32>;
const SymbolTable& symbols_;
const GlobalMap& globals_;
@@ -787,7 +731,7 @@
auto* to = (i + 1 < stack.Length()) ? stack[i + 1] : stack[loop_start];
auto info = DepInfoFor(from, to);
AddNote(diagnostics_,
- KindOf(from->node) + " '" + NameOf(from->node) + "' " + info.action + " " +
+ KindOf(from->node) + " '" + NameOf(from->node) + "' references " +
KindOf(to->node) + " '" + NameOf(to->node) + "' here",
info.source);
}
@@ -831,8 +775,7 @@
/// Global map, keyed by name. Populated by GatherGlobals().
GlobalMap globals_;
- /// Map of DependencyEdge to DependencyInfo. Populated by
- /// DetermineDependencies().
+ /// Map of DependencyEdge to DependencyInfo. Populated by DetermineDependencies().
DependencyEdges dependency_edges_;
/// Globals in declaration order. Populated by GatherGlobals().
@@ -857,9 +800,6 @@
}
std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& diagnostics) const {
- if (!Resolved()) {
- return "<unresolved symbol>";
- }
if (auto* node = Node()) {
return Switch(
node,
@@ -911,6 +851,10 @@
if (auto fmt = TexelFormat(); fmt != builtin::TexelFormat::kUndefined) {
return "texel format '" + utils::ToString(fmt) + "'";
}
+ if (auto* unresolved = Unresolved()) {
+ return "unresolved identifier '" + unresolved->name + "'";
+ }
+
TINT_UNREACHABLE(Resolver, diagnostics) << "unhandled ResolvedIdentifier";
return "<unknown>";
}