[sem] Move TransitivelyReferencedOverrides to sem::Info.
This CL pulls the TransitivelyReferencedOverrides from sem::Array and
sem::GlobalVariable up to the sem::Info.
Moving this data outside of sem::Array removes one of the references to
non-Type sem content.
Bug: tint:1718
Change-Id: I40c1c8b2d5ec60dc2723b56cc30cd436e9b7e997
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112324
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc
index 132bd55..a97c92d 100644
--- a/src/tint/resolver/override_test.cc
+++ b/src/tint/resolver/override_test.cc
@@ -142,7 +142,9 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto& refs = Sem().Get(b)->TransitivelyReferencedOverrides();
+ auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(b));
+ ASSERT_NE(r, nullptr);
+ auto& refs = *r;
ASSERT_EQ(refs.Length(), 1u);
EXPECT_EQ(refs[0], Sem().Get(a));
}
@@ -167,7 +169,9 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto& refs = Sem().Get<sem::GlobalVariable>(b)->TransitivelyReferencedOverrides();
+ auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::GlobalVariable>(b));
+ ASSERT_NE(r, nullptr);
+ auto& refs = *r;
ASSERT_EQ(refs.Length(), 1u);
EXPECT_EQ(refs[0], Sem().Get(a));
}
@@ -215,14 +219,18 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto& refs = Sem().Get(arr_ty)->TransitivelyReferencedOverrides();
+ auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(arr_ty));
+ ASSERT_NE(r, nullptr);
+ auto& refs = *r;
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
}
{
- auto& refs = Sem().Get<sem::GlobalVariable>(arr)->TransitivelyReferencedOverrides();
+ auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::GlobalVariable>(arr));
+ ASSERT_NE(r, nullptr);
+ auto& refs = *r;
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
@@ -251,14 +259,18 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
{
- auto& refs = Sem().Get<sem::Array>(arr_ty->type)->TransitivelyReferencedOverrides();
+ auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::Array>(arr_ty->type));
+ ASSERT_NE(r, nullptr);
+ auto& refs = *r;
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
}
{
- auto& refs = Sem().Get<sem::GlobalVariable>(arr)->TransitivelyReferencedOverrides();
+ auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get<sem::GlobalVariable>(arr));
+ ASSERT_NE(r, nullptr);
+ auto& refs = *r;
ASSERT_EQ(refs.Length(), 2u);
EXPECT_EQ(refs[0], Sem().Get(b));
EXPECT_EQ(refs[1], Sem().Get(a));
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 9fbc7f0..e5b188c 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -917,11 +917,14 @@
// Track the pipeline-overridable constants that are transitively referenced by this variable.
for (auto* var : transitively_referenced_overrides) {
- sem->AddTransitivelyReferencedOverride(var);
+ builder_->Sem().AddTransitivelyReferencedOverride(sem, var);
}
if (auto* arr = sem->Type()->UnwrapRef()->As<sem::Array>()) {
- for (auto* var : arr->TransitivelyReferencedOverrides()) {
- sem->AddTransitivelyReferencedOverride(var);
+ auto* refs = builder_->Sem().TransitivelyReferencedOverrides(arr);
+ if (refs) {
+ for (auto* var : *refs) {
+ builder_->Sem().AddTransitivelyReferencedOverride(sem, var);
+ }
}
}
@@ -2553,8 +2556,11 @@
if (current_function_) {
if (global) {
current_function_->AddDirectlyReferencedGlobal(global);
- for (auto* var : global->TransitivelyReferencedOverrides()) {
- current_function_->AddTransitivelyReferencedGlobal(var);
+ auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global);
+ if (refs) {
+ for (auto* var : *refs) {
+ current_function_->AddTransitivelyReferencedGlobal(var);
+ }
}
}
} else if (variable->Declaration()->Is<ast::Override>()) {
@@ -2562,8 +2568,11 @@
// Track the reference to this pipeline-overridable constant and any other
// pipeline-overridable constants that it references.
resolved_overrides_->Add(global);
- for (auto* var : global->TransitivelyReferencedOverrides()) {
- resolved_overrides_->Add(var);
+ auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global);
+ if (refs) {
+ for (auto* var : *refs) {
+ resolved_overrides_->Add(var);
+ }
}
}
} else if (variable->Declaration()->Is<ast::Var>()) {
@@ -2956,7 +2965,7 @@
// Track the pipeline-overridable constants that are transitively referenced by this array
// type.
for (auto* var : transitively_referenced_overrides) {
- out->AddTransitivelyReferencedOverride(var);
+ builder_->Sem().AddTransitivelyReferencedOverride(out, var);
}
return out;
diff --git a/src/tint/sem/array.h b/src/tint/sem/array.h
index 4047ae4..4d1ed7d 100644
--- a/src/tint/sem/array.h
+++ b/src/tint/sem/array.h
@@ -230,17 +230,6 @@
/// @returns true if this array is runtime sized
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
- /// Records that this array type (transitively) references the given override variable.
- /// @param var the module-scope override variable
- void AddTransitivelyReferencedOverride(const GlobalVariable* var) {
- referenced_overrides_.Add(var);
- }
-
- /// @returns all transitively referenced override variables
- const utils::UniqueVector<const GlobalVariable*, 4>& TransitivelyReferencedOverrides() const {
- return referenced_overrides_;
- }
-
/// @param symbols the program's symbol table
/// @returns the name for this type that closely resembles how it would be
/// declared in WGSL.
@@ -253,7 +242,6 @@
const uint32_t size_;
const uint32_t stride_;
const uint32_t implicit_stride_;
- utils::UniqueVector<const GlobalVariable*, 4> referenced_overrides_;
};
} // namespace tint::sem
diff --git a/src/tint/sem/info.h b/src/tint/sem/info.h
index 894b408..28b41a6 100644
--- a/src/tint/sem/info.h
+++ b/src/tint/sem/info.h
@@ -24,6 +24,7 @@
#include "src/tint/debug.h"
#include "src/tint/sem/node.h"
#include "src/tint/sem/type_mappings.h"
+#include "src/tint/utils/unique_vector.h"
// Forward declarations
namespace tint::sem {
@@ -44,6 +45,9 @@
using GetResultType =
std::conditional_t<std::is_same<SEM, InferFromAST>::value, SemanticNodeTypeFor<AST>, SEM>;
+ /// Alias to a unique vector of transitively referenced global variables
+ using TransitivelyReferenced = utils::UniqueVector<const GlobalVariable*, 4>;
+
/// Constructor
Info();
@@ -117,9 +121,30 @@
/// @returns the semantic module.
const sem::Module* Module() const { return module_; }
+ /// Records that this variable (transitively) references the given override variable.
+ /// @param from the item the variable is referenced from
+ /// @param var the module-scope override variable
+ void AddTransitivelyReferencedOverride(const CastableBase* from, const GlobalVariable* var) {
+ if (referenced_overrides_.count(from) == 0) {
+ referenced_overrides_.insert({from, TransitivelyReferenced{}});
+ }
+ referenced_overrides_[from].Add(var);
+ }
+
+ /// @param from the key to look up
+ /// @returns all transitively referenced override variables or nullptr if none set
+ const TransitivelyReferenced* TransitivelyReferencedOverrides(const CastableBase* from) const {
+ if (referenced_overrides_.count(from) == 0) {
+ return nullptr;
+ }
+ return &referenced_overrides_.at(from);
+ }
+
private:
// AST node index to semantic node
std::vector<const sem::Node*> nodes_;
+ // Lists transitively referenced overrides for the given item
+ std::unordered_map<const CastableBase*, TransitivelyReferenced> referenced_overrides_;
// The semantic module
sem::Module* module_ = nullptr;
};
diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h
index 08dfa98..b977797 100644
--- a/src/tint/sem/variable.h
+++ b/src/tint/sem/variable.h
@@ -183,23 +183,11 @@
/// @returns the location value for the parameter, if set
std::optional<uint32_t> Location() const { return location_; }
- /// Records that this variable (transitively) references the given override variable.
- /// @param var the module-scope override variable
- void AddTransitivelyReferencedOverride(const GlobalVariable* var) {
- referenced_overrides_.Add(var);
- }
-
- /// @returns all transitively referenced override variables
- const utils::UniqueVector<const GlobalVariable*, 4>& TransitivelyReferencedOverrides() const {
- return referenced_overrides_;
- }
-
private:
const sem::BindingPoint binding_point_;
tint::OverrideId override_id_;
std::optional<uint32_t> location_;
- utils::UniqueVector<const GlobalVariable*, 4> referenced_overrides_;
};
/// Parameter is a function parameter
diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc
index 386631e..694cd2d 100644
--- a/src/tint/transform/single_entry_point.cc
+++ b/src/tint/transform/single_entry_point.cc
@@ -71,9 +71,12 @@
[&](const ast::TypeDecl* ty) {
// Strip aliases that reference unused override declarations.
if (auto* arr = sem.Get(ty)->As<sem::Array>()) {
- for (auto* o : arr->TransitivelyReferencedOverrides()) {
- if (!referenced_vars.Contains(o)) {
- return;
+ auto* refs = sem.TransitivelyReferencedOverrides(arr);
+ if (refs) {
+ for (auto* o : *refs) {
+ if (!referenced_vars.Contains(o)) {
+ return;
+ }
}
}
}