tint: static_assert that Sem.Get() template arg is needed

Change-Id: I91a73c22bd417fd9f32d45a1c91ffcb8f8d83d82
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118405
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc b/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc
index 11c8cd6..9f8a14b 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc
+++ b/src/tint/fuzzers/tint_ast_fuzzer/mutation_finders/wrap_unary_operators.cc
@@ -51,7 +51,7 @@
             continue;
         }
 
-        const auto* expr_sem_node = program.Sem().Get<sem::ValueExpression>(expr_ast_node);
+        const auto* expr_sem_node = program.Sem().Get(expr_ast_node);
 
         // Transformation applies only when the semantic node for the given
         // expression is present.
diff --git a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
index 9d65d78..df5c13c 100644
--- a/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
+++ b/src/tint/fuzzers/tint_ast_fuzzer/mutations/wrap_unary_operator.cc
@@ -50,7 +50,7 @@
         return false;
     }
 
-    const auto* expression_sem_node = program.Sem().Get<sem::ValueExpression>(expression_ast_node);
+    const auto* expression_sem_node = program.Sem().Get(expression_ast_node);
 
     if (!expression_sem_node) {
         // Semantic information for the expression ast node is not present
diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc
index 160b3b6..da78956 100644
--- a/src/tint/inspector/inspector_test.cc
+++ b/src/tint/inspector/inspector_test.cc
@@ -1628,16 +1628,16 @@
     EXPECT_EQ(result["v300"].value, 300u);
 
     ASSERT_TRUE(result.count("a"));
-    ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(a));
-    EXPECT_EQ(result["a"], program_->Sem().Get<sem::GlobalVariable>(a)->OverrideId());
+    ASSERT_TRUE(program_->Sem().Get(a));
+    EXPECT_EQ(result["a"], program_->Sem().Get(a)->OverrideId());
 
     ASSERT_TRUE(result.count("b"));
-    ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(b));
-    EXPECT_EQ(result["b"], program_->Sem().Get<sem::GlobalVariable>(b)->OverrideId());
+    ASSERT_TRUE(program_->Sem().Get(b));
+    EXPECT_EQ(result["b"], program_->Sem().Get(b)->OverrideId());
 
     ASSERT_TRUE(result.count("c"));
-    ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(c));
-    EXPECT_EQ(result["c"], program_->Sem().Get<sem::GlobalVariable>(c)->OverrideId());
+    ASSERT_TRUE(program_->Sem().Get(c));
+    EXPECT_EQ(result["c"], program_->Sem().Get(c)->OverrideId());
 }
 
 TEST_F(InspectorGetStorageSizeTest, Empty) {
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 179c339..7082057 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -881,9 +881,11 @@
             continue;
         }
 
+        auto* sem = sem_.Get(override);
+
         OverrideId id;
         if (ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
-            id = builder_->Sem().Get<sem::GlobalVariable>(override)->OverrideId();
+            id = sem->OverrideId();
         } else {
             // No ID was specified, so allocate the next available ID.
             while (!ids_exhausted && override_ids_.Contains(next_id)) {
@@ -899,7 +901,6 @@
             increment_next_id();
         }
 
-        auto* sem = sem_.Get<sem::GlobalVariable>(override);
         const_cast<sem::GlobalVariable*>(sem)->SetOverrideId(id);
     }
     return true;
diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc
index dcd6b06..a71bebc 100644
--- a/src/tint/resolver/resolver_test.cc
+++ b/src/tint/resolver/resolver_test.cc
@@ -499,7 +499,7 @@
     auto* ref = TypeOf(a)->As<type::Reference>();
     ASSERT_NE(ref, nullptr);
     auto* ary = ref->StoreType()->As<type::Array>();
-    auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+    auto* sem_override = Sem().Get(override);
     ASSERT_NE(sem_override, nullptr);
     EXPECT_EQ(ary->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
 }
@@ -524,7 +524,7 @@
     ASSERT_NE(ref_b, nullptr);
     auto* ary_b = ref_b->StoreType()->As<type::Array>();
 
-    auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+    auto* sem_override = Sem().Get(override);
     ASSERT_NE(sem_override, nullptr);
     EXPECT_EQ(ary_a->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
     EXPECT_EQ(ary_b->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
@@ -544,7 +544,7 @@
     auto* ref = TypeOf(a)->As<type::Reference>();
     ASSERT_NE(ref, nullptr);
     auto* ary = ref->StoreType()->As<type::Array>();
-    auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+    auto* sem_override = Sem().Get(override);
     ASSERT_NE(sem_override, nullptr);
     EXPECT_EQ(ary->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(cnt)));
 }
@@ -571,7 +571,7 @@
     ASSERT_NE(ref_b, nullptr);
     auto* ary_b = ref_b->StoreType()->As<type::Array>();
 
-    auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
+    auto* sem_override = Sem().Get(override);
     ASSERT_NE(sem_override, nullptr);
     EXPECT_EQ(ary_a->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(a_cnt)));
     EXPECT_EQ(ary_b->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(b_cnt)));
diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h
index d6a5130..ded9e59 100644
--- a/src/tint/resolver/sem_helper.h
+++ b/src/tint/resolver/sem_helper.h
@@ -52,10 +52,13 @@
     /// @returns the resolved symbol (function, type or variable) for the given ast::Identifier or
     /// ast::TypeName cast to the given semantic type.
     /// @param node the node to retrieve
-    template <typename SEM = CastableBase>
-    SEM* ResolvedSymbol(const ast::Node* node) const {
-        auto resolved = dependencies_.resolved_symbols.Find(node);
-        return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
+    template <typename SEM = sem::Info::InferFromAST>
+    sem::Info::GetResultType<SEM, ast::Node>* ResolvedSymbol(const ast::Node* node) const {
+        if (auto resolved = dependencies_.resolved_symbols.Find(node)) {
+            auto* sem = builder_->Sem().Get<SEM>(*resolved);
+            return const_cast<sem::Info::GetResultType<SEM, ast::Node>*>(sem);
+        }
+        return nullptr;
     }
 
     /// @returns the resolved type of the ast::Expression `expr`
diff --git a/src/tint/sem/info.h b/src/tint/sem/info.h
index c97b014..43a158e 100644
--- a/src/tint/sem/info.h
+++ b/src/tint/sem/info.h
@@ -79,6 +79,9 @@
               typename AST = CastableBase,
               typename RESULT = GetResultType<SEM, AST>>
     const RESULT* Get(const AST* ast_node) const {
+        static_assert(std::is_same_v<SEM, InferFromAST> ||
+                          !traits::IsTypeOrDerived<SemanticNodeTypeFor<AST>, SEM>,
+                      "explicit template argument is unnecessary");
         if (ast_node && ast_node->node_id.value < nodes_.size()) {
             return As<RESULT>(nodes_[ast_node->node_id.value]);
         }
diff --git a/src/tint/transform/direct_variable_access.cc b/src/tint/transform/direct_variable_access.cc
index 2d1b514..9d3d37b 100644
--- a/src/tint/transform/direct_variable_access.cc
+++ b/src/tint/transform/direct_variable_access.cc
@@ -990,7 +990,7 @@
                 return nullptr;  // Just clone the expression.
             }
 
-            auto* expr = sem.Get<sem::ValueExpression>(ast_expr);
+            auto* expr = sem.Get(ast_expr);
             if (!expr) {
                 // No semantic node for the expression.
                 return nullptr;  // Just clone the expression.
diff --git a/src/tint/transform/pad_structs.cc b/src/tint/transform/pad_structs.cc
index 9607851..449cde8 100644
--- a/src/tint/transform/pad_structs.cc
+++ b/src/tint/transform/pad_structs.cc
@@ -59,7 +59,7 @@
     utils::Hashset<const ast::StructMember*, 8> padding_members;
 
     ctx.ReplaceAll([&](const ast::Struct* ast_str) -> const ast::Struct* {
-        auto* str = sem.Get<sem::Struct>(ast_str);
+        auto* str = sem.Get(ast_str);
         if (!str || !str->IsHostShareable()) {
             return nullptr;
         }