[tint] Move override tracking into sem objects
Instead of the sem::Info.
This is required to resolve types without recursion.
Change-Id: I6613d0e19b99a910bddf75dabb47a71cc1713435
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/155144
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/ast/transform/single_entry_point.cc b/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
index c61798d..8daf8d1 100644
--- a/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
+++ b/src/tint/lang/wgsl/ast/transform/single_entry_point.cc
@@ -20,6 +20,7 @@
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
+#include "src/tint/lang/wgsl/sem/array.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/variable.h"
#include "src/tint/utils/rtti/switch.h"
@@ -73,13 +74,10 @@
decl, //
[&](const TypeDecl* ty) {
// Strip aliases that reference unused override declarations.
- if (auto* arr = sem.Get(ty)->As<core::type::Array>()) {
- auto* refs = sem.TransitivelyReferencedOverrides(arr);
- if (refs) {
- for (auto* o : *refs) {
- if (!referenced_vars.Contains(o)) {
- return;
- }
+ if (auto* arr = sem.Get(ty)->As<sem::Array>()) {
+ for (auto* o : arr->TransitivelyReferencedOverrides()) {
+ if (!referenced_vars.Contains(o)) {
+ return;
}
}
}
diff --git a/src/tint/lang/wgsl/resolver/override_test.cc b/src/tint/lang/wgsl/resolver/override_test.cc
index f575a4f..575d58a 100644
--- a/src/tint/lang/wgsl/resolver/override_test.cc
+++ b/src/tint/lang/wgsl/resolver/override_test.cc
@@ -15,6 +15,7 @@
#include "src/tint/lang/wgsl/resolver/resolver.h"
#include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
+#include "src/tint/lang/wgsl/sem/array.h"
using namespace tint::core::number_suffixes; // NOLINT
@@ -134,15 +135,13 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(b));
- ASSERT_NE(r, nullptr);
- auto& refs = *r;
+ auto refs = Sem().Get(b)->TransitivelyReferencedOverrides();
ASSERT_EQ(refs.Length(), 1u);
EXPECT_EQ(refs[0], Sem().Get(a));
}
{
- auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals();
+ auto refs = Sem().Get(func)->TransitivelyReferencedGlobals();
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
@@ -161,9 +160,7 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::GlobalVariable>(b));
- ASSERT_NE(r, nullptr);
- auto& refs = *r;
+ auto refs = Sem().Get<sem::GlobalVariable>(b)->TransitivelyReferencedOverrides();
ASSERT_EQ(refs.Length(), 1u);
EXPECT_EQ(refs[0], Sem().Get(a));
}
@@ -201,7 +198,6 @@
auto* a = Override("a", ty.i32());
auto* b = Override("b", ty.i32(), Mul(2_a, "a"));
auto* arr = GlobalVar("arr", core::AddressSpace::kWorkgroup, ty.array(ty.i32(), Mul(2_a, "b")));
- auto arr_ty = arr->type;
Override("unused", ty.i32(), Expr(1_a));
auto* func = Func("foo", tint::Empty, ty.void_(),
Vector{
@@ -210,26 +206,25 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* global = Sem().Get<sem::GlobalVariable>(arr);
+ ASSERT_NE(global, nullptr);
+ auto* arr_ty = global->Type()->UnwrapRef()->As<sem::Array>();
+ ASSERT_NE(arr_ty, nullptr);
+
{
- auto* r = Sem().TransitivelyReferencedOverrides(TypeOf(arr_ty));
- ASSERT_NE(r, nullptr);
- auto& refs = *r;
+ auto refs = global->TransitivelyReferencedOverrides();
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
}
-
{
- auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::GlobalVariable>(arr));
- ASSERT_NE(r, nullptr);
- auto& refs = *r;
+ auto refs = arr_ty->TransitivelyReferencedOverrides();
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
}
-
{
- auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals();
+ auto refs = Sem().Get(func)->TransitivelyReferencedGlobals();
ASSERT_EQ(refs.Length(), 3u);
EXPECT_EQ(refs[0], Sem().Get(arr));
EXPECT_EQ(refs[1], Sem().Get(b));
@@ -242,7 +237,6 @@
auto* b = Override("b", ty.i32(), Mul(2_a, "a"));
Alias("arr_ty", ty.array(ty.i32(), Mul(2_a, "b")));
auto* arr = GlobalVar("arr", core::AddressSpace::kWorkgroup, ty("arr_ty"));
- auto arr_ty = arr->type;
Override("unused", ty.i32(), Expr(1_a));
auto* func = Func("foo", tint::Empty, ty.void_(),
Vector{
@@ -251,26 +245,25 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
+ auto* global = Sem().Get<sem::GlobalVariable>(arr);
+ ASSERT_NE(global, nullptr);
+ auto* arr_ty = global->Type()->UnwrapRef()->As<sem::Array>();
+ ASSERT_NE(arr_ty, nullptr);
+
{
- auto* r = Sem().TransitivelyReferencedOverrides(TypeOf(arr_ty));
- ASSERT_NE(r, nullptr);
- auto& refs = *r;
+ auto refs = global->TransitivelyReferencedOverrides();
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
}
-
{
- auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::GlobalVariable>(arr));
- ASSERT_NE(r, nullptr);
- auto& refs = *r;
+ auto refs = arr_ty->TransitivelyReferencedOverrides();
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
}
-
{
- auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals();
+ auto refs = Sem().Get(func)->TransitivelyReferencedGlobals();
ASSERT_EQ(refs.Length(), 3u);
EXPECT_EQ(refs[0], Sem().Get(arr));
EXPECT_EQ(refs[1], Sem().Get(b));
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index ff8f66b..29a88ca 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -308,6 +308,13 @@
b.Sem().Add(v, sem);
sem->SetStage(core::EvaluationStage::kOverride);
+ on_transitively_reference_global_.Push([&](const sem::GlobalVariable* ref) {
+ if (ref->Declaration()->Is<ast::Override>()) {
+ sem->AddTransitivelyReferencedOverride(ref);
+ }
+ });
+ TINT_DEFER(on_transitively_reference_global_.Pop());
+
// If the variable has a declared type, resolve it.
const core::type::Type* ty = nullptr;
if (v->type) {
@@ -403,8 +410,10 @@
sem::Variable* Resolver::Const(const ast::Const* c, bool is_global) {
sem::Variable* sem = nullptr;
+ sem::GlobalVariable* global = nullptr;
if (is_global) {
- sem = b.create<sem::GlobalVariable>(c);
+ global = b.create<sem::GlobalVariable>(c);
+ sem = global;
} else {
sem = b.create<sem::LocalVariable>(c, current_statement_);
}
@@ -487,6 +496,19 @@
sem->SetStage(core::EvaluationStage::kRuntime);
b.Sem().Add(var, sem);
+ if (is_global) {
+ on_transitively_reference_global_.Push([&](const sem::GlobalVariable* ref) {
+ if (ref->Declaration()->Is<ast::Override>()) {
+ global->AddTransitivelyReferencedOverride(ref);
+ }
+ });
+ }
+ TINT_DEFER({
+ if (is_global) {
+ on_transitively_reference_global_.Pop();
+ }
+ });
+
// If the variable has a declared type, resolve it.
const core::type::Type* storage_ty = nullptr;
if (auto ty = var->type) {
@@ -881,9 +903,6 @@
}
sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
- UniqueVector<const sem::GlobalVariable*, 4> transitively_referenced_overrides;
- TINT_SCOPED_ASSIGNMENT(resolved_overrides_, &transitively_referenced_overrides);
-
auto* sem = As<sem::GlobalVariable>(Variable(v, /* is_global */ true));
if (!sem) {
return nullptr;
@@ -897,20 +916,6 @@
return nullptr;
}
- // Track the pipeline-overridable constants that are transitively referenced by this
- // variable.
- for (auto* var : transitively_referenced_overrides) {
- b.Sem().AddTransitivelyReferencedOverride(sem, var);
- }
- if (auto* arr = sem->Type()->UnwrapRef()->As<sem::Array>()) {
- auto* refs = b.Sem().TransitivelyReferencedOverrides(arr);
- if (refs) {
- for (auto* var : *refs) {
- b.Sem().AddTransitivelyReferencedOverride(sem, var);
- }
- }
- }
-
return sem;
}
@@ -943,6 +948,11 @@
b.Sem().Add(decl, func);
TINT_SCOPED_ASSIGNMENT(current_function_, func);
+ on_transitively_reference_global_.Push([&](const sem::GlobalVariable* ref) { //
+ func->AddDirectlyReferencedGlobal(ref);
+ });
+ TINT_DEFER(on_transitively_reference_global_.Pop());
+
validator_.DiagnosticFilters().Push();
TINT_DEFER(validator_.DiagnosticFilters().Pop());
@@ -1566,6 +1576,14 @@
}
core::type::Type* Resolver::Type(const ast::Expression* ast) {
+ Vector<const sem::GlobalVariable*, 4> referenced_overrides;
+ on_transitively_reference_global_.Push([&](const sem::GlobalVariable* ref) {
+ if (ref->Declaration()->Is<ast::Override>()) {
+ referenced_overrides.Push(ref);
+ }
+ });
+ TINT_DEFER(on_transitively_reference_global_.Pop());
+
auto* type_expr = TypeExpression(ast);
if (TINT_UNLIKELY(!type_expr)) {
return nullptr;
@@ -1581,6 +1599,13 @@
ast->source.End());
return nullptr;
}
+
+ if (auto* arr = type->As<sem::Array>()) {
+ for (auto* ref : referenced_overrides) {
+ arr->AddTransitivelyReferencedOverride(ref);
+ }
+ }
+
return type;
}
@@ -2750,9 +2775,6 @@
}
core::type::Type* Resolver::Array(const ast::Identifier* ident) {
- UniqueVector<const sem::GlobalVariable*, 4> transitively_referenced_overrides;
- TINT_SCOPED_ASSIGNMENT(resolved_overrides_, &transitively_referenced_overrides);
-
auto* tmpl_ident = ident->As<ast::TemplatedIdentifier>();
if (!tmpl_ident) {
// 'array' has no template arguments, so return an incomplete type.
@@ -2798,11 +2820,6 @@
}
}
- // Track the pipeline-overridable constants that are transitively referenced by this
- // array type.
- for (auto* var : transitively_referenced_overrides) {
- b.Sem().AddTransitivelyReferencedOverride(out, var);
- }
return out;
}
@@ -3233,38 +3250,17 @@
}
}
- auto* global = variable->As<sem::GlobalVariable>();
- if (current_function_) {
- if (global) {
- current_function_->AddDirectlyReferencedGlobal(global);
- auto* refs = b.Sem().TransitivelyReferencedOverrides(global);
- if (refs) {
- for (auto* var : *refs) {
- current_function_->AddTransitivelyReferencedGlobal(var);
- }
- }
+ if (auto* global = variable->As<sem::GlobalVariable>()) {
+ for (auto& fn : on_transitively_reference_global_) {
+ fn(global);
}
- } else if (variable->Declaration()->Is<ast::Override>()) {
- if (resolved_overrides_) {
- // Track the reference to this pipeline-overridable constant and any other
- // pipeline-overridable constants that it references.
- resolved_overrides_->Add(global);
- auto* refs = b.Sem().TransitivelyReferencedOverrides(global);
- if (refs) {
- for (auto* var : *refs) {
- resolved_overrides_->Add(var);
- }
- }
+ if (!current_function_ && variable->Declaration()->Is<ast::Var>()) {
+ // Use of a module-scope 'var' outside of a function.
+ std::string desc = "var '" + ident->symbol.Name() + "' ";
+ AddError(desc + "cannot be referenced at module-scope", expr->source);
+ AddNote(desc + "declared here", variable->Declaration()->source);
+ return nullptr;
}
- } else if (variable->Declaration()->Is<ast::Var>()) {
- // Use of a module-scope 'var' outside of a function.
- // Note: The spec is currently vague around the rules here. See
- // https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when
- // resolved.
- std::string desc = "var '" + ident->symbol.Name() + "' ";
- AddError(desc + "cannot be referenced at module-scope", expr->source);
- AddNote(desc + "declared here", variable->Declaration()->source);
- return nullptr;
}
variable->AddUser(user);
@@ -3275,6 +3271,16 @@
if (!TINT_LIKELY(CheckNotTemplated("type", ident))) {
return nullptr;
}
+
+ // Notify callers of all transitively referenced globals.
+ if (auto* arr = ty->As<sem::Array>()) {
+ for (auto& fn : on_transitively_reference_global_) {
+ for (auto* ref : arr->TransitivelyReferencedOverrides()) {
+ fn(ref);
+ }
+ }
+ }
+
return b.create<sem::TypeExpression>(expr, current_statement_, ty);
},
[&](const sem::Function* fn) -> sem::FunctionExpression* {
@@ -4094,10 +4100,10 @@
core::type::Type* Resolver::Alias(const ast::Alias* alias) {
auto* ty = Type(alias->type);
- if (!ty) {
+ if (TINT_UNLIKELY(!ty)) {
return nullptr;
}
- if (!validator_.Alias(alias)) {
+ if (TINT_UNLIKELY(!validator_.Alias(alias))) {
return nullptr;
}
return ty;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index 45e487f..2e1cbde 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -683,10 +683,9 @@
Hashmap<StructConstructorSig, sem::CallTarget*, 8> struct_ctors_;
sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
- sem::GlobalVariable* current_global_var_ = nullptr;
sem::CompoundStatement* current_compound_statement_ = nullptr;
+ Vector<std::function<void(const sem::GlobalVariable*)>, 4> on_transitively_reference_global_;
uint32_t current_scoping_depth_ = 0;
- UniqueVector<const sem::GlobalVariable*, 4>* resolved_overrides_ = nullptr;
Hashset<TypeAndAddressSpace, 8> valid_type_storage_layouts_;
Hashmap<const ast::Expression*, const ast::BinaryExpression*, 8> logical_binary_lhs_to_parent_;
Hashset<const ast::Expression*, 8> skip_const_eval_;
diff --git a/src/tint/lang/wgsl/sem/array.cc b/src/tint/lang/wgsl/sem/array.cc
index 3793bb6..a2aae60 100644
--- a/src/tint/lang/wgsl/sem/array.cc
+++ b/src/tint/lang/wgsl/sem/array.cc
@@ -14,6 +14,8 @@
#include "src/tint/lang/wgsl/sem/array.h"
+#include "src/tint/lang/wgsl/sem/variable.h"
+
TINT_INSTANTIATE_TYPEINFO(tint::sem::Array);
namespace tint::sem {
@@ -28,4 +30,11 @@
Array::~Array() = default;
+void Array::AddTransitivelyReferencedOverride(const GlobalVariable* var) {
+ transitively_referenced_overrides_.Add(var);
+ for (auto* ref : var->TransitivelyReferencedOverrides()) {
+ AddTransitivelyReferencedOverride(ref);
+ }
+}
+
} // namespace tint::sem
diff --git a/src/tint/lang/wgsl/sem/array.h b/src/tint/lang/wgsl/sem/array.h
index 82c635a..e634a0f 100644
--- a/src/tint/lang/wgsl/sem/array.h
+++ b/src/tint/lang/wgsl/sem/array.h
@@ -48,6 +48,18 @@
/// Destructor
~Array() override;
+
+ /// Records that this variable (transitively) references the given override variable.
+ /// @param var the module-scope override variable
+ void AddTransitivelyReferencedOverride(const GlobalVariable* var);
+
+ /// @returns all transitively referenced override variables
+ VectorRef<const GlobalVariable*> TransitivelyReferencedOverrides() const {
+ return transitively_referenced_overrides_;
+ }
+
+ private:
+ UniqueVector<const GlobalVariable*, 4> transitively_referenced_overrides_;
};
} // namespace tint::sem
diff --git a/src/tint/lang/wgsl/sem/function.cc b/src/tint/lang/wgsl/sem/function.cc
index 1d1b9e3..bb6a8a2 100644
--- a/src/tint/lang/wgsl/sem/function.cc
+++ b/src/tint/lang/wgsl/sem/function.cc
@@ -52,6 +52,14 @@
return ret;
}
+void Function::AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global) {
+ if (transitively_referenced_globals_.Add(global)) {
+ for (auto* ref : global->TransitivelyReferencedOverrides()) {
+ AddTransitivelyReferencedGlobal(ref);
+ }
+ }
+}
+
Function::VariableBindings Function::TransitivelyReferencedUniformVariables() const {
VariableBindings ret;
diff --git a/src/tint/lang/wgsl/sem/function.h b/src/tint/lang/wgsl/sem/function.h
index 585ede0..c0da248 100644
--- a/src/tint/lang/wgsl/sem/function.h
+++ b/src/tint/lang/wgsl/sem/function.h
@@ -86,11 +86,11 @@
}
/// Records that this function directly references the given global variable.
- /// Note: Implicitly adds this global to the transtively-called globals.
+ /// Note: Implicitly adds this global to the transitively-called globals.
/// @param global the module-scope variable
void AddDirectlyReferencedGlobal(const sem::GlobalVariable* global) {
directly_referenced_globals_.Add(global);
- transitively_referenced_globals_.Add(global);
+ AddTransitivelyReferencedGlobal(global);
}
/// @returns all transitively referenced global variables
@@ -101,9 +101,7 @@
/// Records that this function transitively references the given global
/// variable.
/// @param global the module-scoped variable
- void AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global) {
- transitively_referenced_globals_.Add(global);
- }
+ void AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global);
/// @returns the list of functions that this function transitively calls.
const UniqueVector<const Function*, 8>& TransitivelyCalledFunctions() const {
diff --git a/src/tint/lang/wgsl/sem/info.h b/src/tint/lang/wgsl/sem/info.h
index 5a41a3f..a4ab07c 100644
--- a/src/tint/lang/wgsl/sem/info.h
+++ b/src/tint/lang/wgsl/sem/info.h
@@ -51,9 +51,6 @@
using GetResultType =
std::conditional_t<std::is_same<SEM, InferFromAST>::value, SemanticNodeTypeFor<AST>, SEM>;
- /// Alias to a unique vector of transitively referenced global variables
- using TransitivelyReferenced = UniqueVector<const GlobalVariable*, 4>;
-
/// Constructor
Info();
@@ -138,25 +135,6 @@
/// @returns the semantic module.
const sem::Module* Module() const { return module_; }
- /// Records that this variable (transitively) references the given override variable.
- /// @param from the item the variable is referenced from
- /// @param var the module-scope override variable
- void AddTransitivelyReferencedOverride(const CastableBase* from, const GlobalVariable* var) {
- if (referenced_overrides_.count(from) == 0) {
- referenced_overrides_.insert({from, TransitivelyReferenced{}});
- }
- referenced_overrides_[from].Add(var);
- }
-
- /// @param from the key to look up
- /// @returns all transitively referenced override variables or nullptr if none set
- const TransitivelyReferenced* TransitivelyReferencedOverrides(const CastableBase* from) const {
- if (referenced_overrides_.count(from) == 0) {
- return nullptr;
- }
- return &referenced_overrides_.at(from);
- }
-
/// Determines the severity of a filterable diagnostic rule for the AST node `ast_node`.
/// @param ast_node the AST node
/// @param rule the diagnostic rule
@@ -167,8 +145,6 @@
private:
// AST node index to semantic node
std::vector<const CastableBase*> nodes_;
- // Lists transitively referenced overrides for the given item
- std::unordered_map<const CastableBase*, TransitivelyReferenced> referenced_overrides_;
// The semantic module
sem::Module* module_ = nullptr;
};
diff --git a/src/tint/lang/wgsl/sem/variable.cc b/src/tint/lang/wgsl/sem/variable.cc
index d86f4ec..54849f2 100644
--- a/src/tint/lang/wgsl/sem/variable.cc
+++ b/src/tint/lang/wgsl/sem/variable.cc
@@ -41,6 +41,14 @@
GlobalVariable::~GlobalVariable() = default;
+void GlobalVariable::AddTransitivelyReferencedOverride(const GlobalVariable* var) {
+ if (transitively_referenced_overrides_.Add(var)) {
+ for (auto* ref : var->TransitivelyReferencedOverrides()) {
+ AddTransitivelyReferencedOverride(ref);
+ }
+ }
+}
+
Parameter::Parameter(const ast::Parameter* declaration,
uint32_t index /* = 0 */,
const core::type::Type* type /* = nullptr */,
diff --git a/src/tint/lang/wgsl/sem/variable.h b/src/tint/lang/wgsl/sem/variable.h
index 82d9603..7a7b826 100644
--- a/src/tint/lang/wgsl/sem/variable.h
+++ b/src/tint/lang/wgsl/sem/variable.h
@@ -176,11 +176,21 @@
/// @returns the index value for the parameter, if set
std::optional<uint32_t> Index() const { return index_; }
+ /// Records that this variable (transitively) references the given override variable.
+ /// @param var the module-scope override variable
+ void AddTransitivelyReferencedOverride(const GlobalVariable* var);
+
+ /// @returns all transitively referenced override variables
+ VectorRef<const GlobalVariable*> TransitivelyReferencedOverrides() const {
+ return transitively_referenced_overrides_;
+ }
+
private:
std::optional<tint::BindingPoint> binding_point_;
tint::OverrideId override_id_;
std::optional<uint32_t> location_;
std::optional<uint32_t> index_;
+ UniqueVector<const GlobalVariable*, 4> transitively_referenced_overrides_;
};
/// Parameter is a function parameter