[ast] Add helpers for searching a decoration list

This is a commonly used pattern.

Change-Id: I698397c93c33db64c53cbe8662186e1976075b80
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47280
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/decoration.h b/src/ast/decoration.h
index b9b5a3c..724a07c 100644
--- a/src/ast/decoration.h
+++ b/src/ast/decoration.h
@@ -36,6 +36,30 @@
 /// A list of decorations
 using DecorationList = std::vector<Decoration*>;
 
+/// @param decorations the list of decorations to search
+/// @returns true if `decorations` includes a decoration of type `T`
+template <typename T>
+bool HasDecoration(const DecorationList& decorations) {
+  for (auto* deco : decorations) {
+    if (deco->Is<T>()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/// @param decorations the list of decorations to search
+/// @returns a pointer to `T` from `decorations` if found, otherwise nullptr.
+template <typename T>
+T* GetDecoration(const DecorationList& decorations) {
+  for (auto* deco : decorations) {
+    if (deco->Is<T>()) {
+      return deco->As<T>();
+    }
+  }
+  return nullptr;
+}
+
 }  // namespace ast
 }  // namespace tint
 
diff --git a/src/ast/function.cc b/src/ast/function.cc
index db0cffd..8f8ab73 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -49,19 +49,15 @@
 Function::~Function() = default;
 
 std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
-  for (auto* deco : decorations_) {
-    if (auto* workgroup = deco->As<WorkgroupDecoration>()) {
-      return workgroup->values();
-    }
+  if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
+    return workgroup->values();
   }
   return {1, 1, 1};
 }
 
 PipelineStage Function::pipeline_stage() const {
-  for (auto* deco : decorations_) {
-    if (auto* stage = deco->As<StageDecoration>()) {
-      return stage->value();
-    }
+  if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
+    return stage->value();
   }
   return PipelineStage::kNone;
 }
diff --git a/src/ast/function.h b/src/ast/function.h
index 1352a71..eb856c3 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -63,18 +63,6 @@
   /// @returns the decorations attached to this function
   const DecorationList& decorations() const { return decorations_; }
 
-  /// @returns the decoration with the type `T` or nullptr if this function does
-  /// not contain a decoration with the given type
-  template <typename T>
-  const T* find_decoration() const {
-    for (auto* deco : decorations()) {
-      if (auto* d = deco->As<T>()) {
-        return d;
-      }
-    }
-    return nullptr;
-  }
-
   /// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
   /// return if no workgroup size was set.
   std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
diff --git a/src/ast/struct.cc b/src/ast/struct.cc
index 7d857b7..e38d350 100644
--- a/src/ast/struct.cc
+++ b/src/ast/struct.cc
@@ -50,12 +50,7 @@
 }
 
 bool Struct::IsBlockDecorated() const {
-  for (auto* deco : decorations_) {
-    if (deco->Is<StructBlockDecoration>()) {
-      return true;
-    }
-  }
-  return false;
+  return HasDecoration<StructBlockDecoration>(decorations_);
 }
 
 Struct* Struct::Clone(CloneContext* ctx) const {
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc
index c576ee7..70b803f 100644
--- a/src/ast/struct_member.cc
+++ b/src/ast/struct_member.cc
@@ -41,19 +41,13 @@
 StructMember::~StructMember() = default;
 
 bool StructMember::has_offset_decoration() const {
-  for (auto* deco : decorations_) {
-    if (deco->Is<StructMemberOffsetDecoration>()) {
-      return true;
-    }
-  }
-  return false;
+  return HasDecoration<StructMemberOffsetDecoration>(decorations_);
 }
 
 uint32_t StructMember::offset() const {
-  for (auto* deco : decorations_) {
-    if (auto* offset = deco->As<StructMemberOffsetDecoration>()) {
-      return offset->offset();
-    }
+  if (auto* offset =
+          GetDecoration<StructMemberOffsetDecoration>(decorations_)) {
+    return offset->offset();
   }
   return 0;
 }
diff --git a/src/ast/variable.cc b/src/ast/variable.cc
index 3c7d5a0..24308b5 100644
--- a/src/ast/variable.cc
+++ b/src/ast/variable.cc
@@ -59,49 +59,11 @@
   return BindingPoint{group, binding};
 }
 
-bool Variable::HasLocationDecoration() const {
-  for (auto* deco : decorations_) {
-    if (deco->Is<LocationDecoration>()) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool Variable::HasBuiltinDecoration() const {
-  for (auto* deco : decorations_) {
-    if (deco->Is<BuiltinDecoration>()) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool Variable::HasConstantIdDecoration() const {
-  for (auto* deco : decorations_) {
-    if (deco->Is<ConstantIdDecoration>()) {
-      return true;
-    }
-  }
-  return false;
-}
-
-LocationDecoration* Variable::GetLocationDecoration() const {
-  for (auto* deco : decorations_) {
-    if (deco->Is<LocationDecoration>()) {
-      return deco->As<LocationDecoration>();
-    }
-  }
-  return nullptr;
-}
-
 uint32_t Variable::constant_id() const {
-  TINT_ASSERT(HasConstantIdDecoration());
-  for (auto* deco : decorations_) {
-    if (auto* cid = deco->As<ConstantIdDecoration>()) {
-      return cid->value();
-    }
+  if (auto* cid = GetDecoration<ConstantIdDecoration>(decorations_)) {
+    return cid->value();
   }
+  TINT_ASSERT(false);
   return 0;
 }
 
diff --git a/src/ast/variable.h b/src/ast/variable.h
index b8f50e2..a512d01 100644
--- a/src/ast/variable.h
+++ b/src/ast/variable.h
@@ -134,18 +134,8 @@
   /// @returns the binding point information for the variable
   BindingPoint binding_point() const;
 
-  /// @returns true if the decorations include a LocationDecoration
-  bool HasLocationDecoration() const;
-  /// @returns true if the decorations include a BuiltinDecoration
-  bool HasBuiltinDecoration() const;
-  /// @returns true if the decorations include a ConstantIdDecoration
-  bool HasConstantIdDecoration() const;
-
-  /// @returns pointer to LocationDecoration in decorations, otherwise NULL.
-  LocationDecoration* GetLocationDecoration() const;
-
-  /// @returns the constant_id value for the variable. Assumes that
-  /// HasConstantIdDecoration() has been called first.
+  /// @returns the constant_id value for the variable. Assumes that this
+  /// variable has a constant ID decoration.
   uint32_t constant_id() const;
 
   /// Clones this node and all transitive child nodes using the `CloneContext`
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc
index 258f988..929e922 100644
--- a/src/ast/variable_test.cc
+++ b/src/ast/variable_test.cc
@@ -98,11 +98,12 @@
                       create<ConstantIdDecoration>(1200),
                   });
 
-  EXPECT_TRUE(var->HasLocationDecoration());
-  EXPECT_TRUE(var->HasBuiltinDecoration());
-  EXPECT_TRUE(var->HasConstantIdDecoration());
+  auto& decorations = var->decorations();
+  EXPECT_TRUE(ast::HasDecoration<ast::LocationDecoration>(decorations));
+  EXPECT_TRUE(ast::HasDecoration<ast::BuiltinDecoration>(decorations));
+  EXPECT_TRUE(ast::HasDecoration<ast::ConstantIdDecoration>(decorations));
 
-  auto* location = var->GetLocationDecoration();
+  auto* location = ast::GetDecoration<ast::LocationDecoration>(decorations);
   ASSERT_NE(nullptr, location);
   EXPECT_EQ(1u, location->value());
 }
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index 8e8770c..8ce0901 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -18,6 +18,7 @@
 
 #include "src/ast/bool_literal.h"
 #include "src/ast/float_literal.h"
+#include "src/ast/constant_id_decoration.h"
 #include "src/ast/module.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
@@ -203,7 +204,7 @@
       auto* decl = var->Declaration();
 
       auto name = program_->Symbols().NameFor(decl->symbol());
-      if (decl->HasBuiltinDecoration()) {
+      if (ast::HasDecoration<ast::BuiltinDecoration>(decl->decorations())) {
         continue;
       }
 
@@ -220,7 +221,8 @@
         stage_variable.component_type = ComponentType::kSInt;
       }
 
-      auto* location_decoration = decl->GetLocationDecoration();
+      auto* location_decoration =
+          ast::GetDecoration<ast::LocationDecoration>(decl->decorations());
       if (location_decoration) {
         stage_variable.has_location_decoration = true;
         stage_variable.location_decoration = location_decoration->value();
@@ -257,7 +259,7 @@
 std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
   std::map<uint32_t, Scalar> result;
   for (auto* var : program_->AST().GlobalVariables()) {
-    if (!var->HasConstantIdDecoration()) {
+    if (!ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
       continue;
     }
 
diff --git a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
index 3a72f04..f63a0e0 100644
--- a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "src/ast/constant_id_decoration.h"
 #include "src/reader/wgsl/parser_impl_test_helper.h"
 
 namespace tint {
@@ -43,7 +44,8 @@
   ASSERT_NE(e->constructor(), nullptr);
   EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
 
-  EXPECT_FALSE(e.value->HasConstantIdDecoration());
+  EXPECT_FALSE(
+      ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
 }
 
 TEST_F(ParserImplTest, GlobalConstantDecl_MissingEqual) {
@@ -123,7 +125,8 @@
   ASSERT_NE(e->constructor(), nullptr);
   EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
 
-  EXPECT_TRUE(e.value->HasConstantIdDecoration());
+  EXPECT_TRUE(
+      ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
   EXPECT_EQ(e.value->constant_id(), 7u);
 }
 
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index da891d5..c9110a4 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -310,7 +310,8 @@
             func->source());
         return false;
       }
-    } else if (!func->find_decoration<ast::InternalDecoration>()) {
+    } else if (!ast::HasDecoration<ast::InternalDecoration>(
+                   func->decorations())) {
       TINT_ICE(diagnostics_)
           << "Function " << builder_->Symbols().NameFor(func->symbol())
           << " has no body and does not have the [[internal]] decoration";
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 4f82f6c..ed508a7 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1310,8 +1310,9 @@
 }
 
 bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
-  if (var->Declaration()->HasLocationDecoration() ||
-      var->Declaration()->HasBuiltinDecoration()) {
+  auto& decorations = var->Declaration()->decorations();
+  if (ast::HasDecoration<ast::LocationDecoration>(decorations) ||
+      ast::HasDecoration<ast::BuiltinDecoration>(decorations)) {
     return var->StorageClass() == ast::StorageClass::kInput ||
            var->StorageClass() == ast::StorageClass::kOutput;
   }
@@ -1463,7 +1464,7 @@
 
   auto* func_sem = builder_.Sem().Get(func);
 
-  if (func->find_decoration<ast::InternalDecoration>()) {
+  if (ast::HasDecoration<ast::InternalDecoration>(func->decorations())) {
     // An internal function. Do not emit.
     return true;
   }
@@ -2825,7 +2826,7 @@
 
   auto* type = builder_.Sem().Get(var)->Type();
 
-  if (var->HasConstantIdDecoration()) {
+  if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
     auto const_id = var->constant_id();
 
     out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index d7e47e1..98a2a85 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1562,12 +1562,15 @@
 }
 
 bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
+  auto& decorations = var->Declaration()->decorations();
   bool in_or_out_struct_has_location =
-      var != nullptr && var->Declaration()->HasLocationDecoration() &&
+      var != nullptr &&
+      ast::HasDecoration<ast::LocationDecoration>(decorations) &&
       (var->StorageClass() == ast::StorageClass::kInput ||
        var->StorageClass() == ast::StorageClass::kOutput);
   bool in_struct_has_builtin =
-      var != nullptr && var->Declaration()->HasBuiltinDecoration() &&
+      var != nullptr &&
+      ast::HasDecoration<ast::BuiltinDecoration>(decorations) &&
       var->StorageClass() == ast::StorageClass::kOutput;
   return in_or_out_struct_has_location || in_struct_has_builtin;
 }
@@ -2249,7 +2252,7 @@
     out_ << " " << program_->Symbols().NameFor(var->symbol());
   }
 
-  if (var->HasConstantIdDecoration()) {
+  if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
     out_ << " [[function_constant(" << var->constant_id() << ")]]";
   } else if (var->constructor() != nullptr) {
     out_ << " = ";
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index af2ec2f..005c606 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -750,7 +750,7 @@
     //    one
     // 2- If we don't have a constructor and we're an Output or Private variable
     //    then WGSL requires an initializer.
-    if (var->HasConstantIdDecoration()) {
+    if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
       if (type_no_ac->Is<type::F32>()) {
         ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
         init_id = GenerateLiteralIfNeeded(var, &l);
@@ -1490,7 +1490,8 @@
                                           ast::Literal* lit) {
   ScalarConstant constant;
 
-  if (var && var->HasConstantIdDecoration()) {
+  if (var &&
+      ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
     constant.is_spec_op = true;
     constant.constant_id = var->constant_id();
   }