tint: static_assert that Sem.Get() template arg is needed
Change-Id: I91a73c22bd417fd9f32d45a1c91ffcb8f8d83d82
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118405
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc b/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc
index 11c8cd6..9f8a14b 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc
+++ b/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc
@@ -51,7 +51,7 @@
continue;
}
- const auto* expr_sem_node = program.Sem().Get<sem::ValueExpression>(expr_ast_node);
+ const auto* expr_sem_node = program.Sem().Get(expr_ast_node);
// Transformation applies only when the semantic node for the given
// expression is present.
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
index 9d65d78..df5c13c 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
+++ b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
@@ -50,7 +50,7 @@
return false;
}
- const auto* expression_sem_node = program.Sem().Get<sem::ValueExpression>(expression_ast_node);
+ const auto* expression_sem_node = program.Sem().Get(expression_ast_node);
if (!expression_sem_node) {
// Semantic information for the expression ast node is not present
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 160b3b6..da78956 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -1628,16 +1628,16 @@
EXPECT_EQ(result["v300"].value, 300u);
ASSERT_TRUE(result.count("a"));
- ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(a));
- EXPECT_EQ(result["a"], program_->Sem().Get<sem::GlobalVariable>(a)->OverrideId());
+ ASSERT_TRUE(program_->Sem().Get(a));
+ EXPECT_EQ(result["a"], program_->Sem().Get(a)->OverrideId());
ASSERT_TRUE(result.count("b"));
- ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(b));
- EXPECT_EQ(result["b"], program_->Sem().Get<sem::GlobalVariable>(b)->OverrideId());
+ ASSERT_TRUE(program_->Sem().Get(b));
+ EXPECT_EQ(result["b"], program_->Sem().Get(b)->OverrideId());
ASSERT_TRUE(result.count("c"));
- ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(c));
- EXPECT_EQ(result["c"], program_->Sem().Get<sem::GlobalVariable>(c)->OverrideId());
+ ASSERT_TRUE(program_->Sem().Get(c));
+ EXPECT_EQ(result["c"], program_->Sem().Get(c)->OverrideId());
}
TEST_F(InspectorGetStorageSizeTest, Empty) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 179c339..7082057 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -881,9 +881,11 @@
continue;
}
+ auto* sem = sem_.Get(override);
+
OverrideId id;
if (ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
- id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId();
+ id = sem->OverrideId();
} else {
// No ID was specified, so allocate the next available ID.
while (!ids_exhausted && override_ids_.Contains(next_id)) {
@@ -899,7 +901,6 @@
increment_next_id();
}
- auto* sem = sem_.Get<sem::GlobalVariable>(override);
const_cast<sem::GlobalVariable*>(sem)->SetOverrideId(id);
}
return true;
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index dcd6b06..a71bebc 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -499,7 +499,7 @@
auto* ref = TypeOf(a)->As<type::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<type::Array>();
- auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+ auto* sem_override = Sem().Get(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
}
@@ -524,7 +524,7 @@
ASSERT_NE(ref_b, nullptr);
auto* ary_b = ref_b->StoreType()->As<type::Array>();
- auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+ auto* sem_override = Sem().Get(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary_a->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
EXPECT_EQ(ary_b->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
@@ -544,7 +544,7 @@
auto* ref = TypeOf(a)->As<type::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<type::Array>();
- auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+ auto* sem_override = Sem().Get(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(cnt)));
}
@@ -571,7 +571,7 @@
ASSERT_NE(ref_b, nullptr);
auto* ary_b = ref_b->StoreType()->As<type::Array>();
- auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+ auto* sem_override = Sem().Get(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary_a->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(a_cnt)));
EXPECT_EQ(ary_b->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(b_cnt)));
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index d6a5130..ded9e59 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -52,10 +52,13 @@
/// @returns the resolved symbol (function, type or variable) for the given ast::Identifier or
/// ast::TypeName cast to the given semantic type.
/// @param node the node to retrieve
- template <typename SEM = CastableBase>
- SEM* ResolvedSymbol(const ast::Node* node) const {
- auto resolved = dependencies_.resolved_symbols.Find(node);
- return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
+ template <typename SEM = sem::Info::InferFromAST>
+ sem::Info::GetResultType<SEM, ast::Node>* ResolvedSymbol(const ast::Node* node) const {
+ if (auto resolved = dependencies_.resolved_symbols.Find(node)) {
+ auto* sem = builder_->Sem().Get<SEM>(*resolved);
+ return const_cast<sem::Info::GetResultType<SEM, ast::Node>*>(sem);
+ }
+ return nullptr;
}
/// @returns the resolved type of the ast::Expression `expr`
diff --git a/src/tint/sem/info.h b/src/tint/sem/info.h
index c97b014..43a158e 100644
--- a/src/tint/sem/info.h
+++ b/src/tint/sem/info.h
@@ -79,6 +79,9 @@
typename AST = CastableBase,
typename RESULT = GetResultType<SEM, AST>>
const RESULT* Get(const AST* ast_node) const {
+ static_assert(std::is_same_v<SEM, InferFromAST> ||
+ !traits::IsTypeOrDerived<SemanticNodeTypeFor<AST>, SEM>,
+ "explicit template argument is unnecessary");
if (ast_node && ast_node->node_id.value < nodes_.size()) {
return As<RESULT>(nodes_[ast_node->node_id.value]);
}
diff --git a/src/tint/transform/direct_variable_access.cc b/src/tint/transform/direct_variable_access.cc
index 2d1b514..9d3d37b 100644
--- a/src/tint/transform/direct_variable_access.cc
+++ b/src/tint/transform/direct_variable_access.cc
@@ -990,7 +990,7 @@
return nullptr; // Just clone the expression.
}
- auto* expr = sem.Get<sem::ValueExpression>(ast_expr);
+ auto* expr = sem.Get(ast_expr);
if (!expr) {
// No semantic node for the expression.
return nullptr; // Just clone the expression.
diff --git a/src/tint/transform/pad_structs.cc b/src/tint/transform/pad_structs.cc
index 9607851..449cde8 100644
--- a/src/tint/transform/pad_structs.cc
+++ b/src/tint/transform/pad_structs.cc
@@ -59,7 +59,7 @@
utils::Hashset<const ast::StructMember*, 8> padding_members;
ctx.ReplaceAll([&](const ast::Struct* ast_str) -> const ast::Struct* {
- auto* str = sem.Get<sem::Struct>(ast_str);
+ auto* str = sem.Get(ast_str);
if (!str || !str->IsHostShareable()) {
return nullptr;
}