[ast] Add helpers for searching a decoration list
This is a commonly used pattern.
Change-Id: I698397c93c33db64c53cbe8662186e1976075b80
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47280
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/decoration.h b/src/ast/decoration.h
index b9b5a3c..724a07c 100644
--- a/src/ast/decoration.h
+++ b/src/ast/decoration.h
@@ -36,6 +36,30 @@
/// A list of decorations
using DecorationList = std::vector<Decoration*>;
+/// @param decorations the list of decorations to search
+/// @returns true if `decorations` includes a decoration of type `T`
+template <typename T>
+bool HasDecoration(const DecorationList& decorations) {
+ for (auto* deco : decorations) {
+ if (deco->Is<T>()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/// @param decorations the list of decorations to search
+/// @returns a pointer to `T` from `decorations` if found, otherwise nullptr.
+template <typename T>
+T* GetDecoration(const DecorationList& decorations) {
+ for (auto* deco : decorations) {
+ if (deco->Is<T>()) {
+ return deco->As<T>();
+ }
+ }
+ return nullptr;
+}
+
} // namespace ast
} // namespace tint
diff --git a/src/ast/function.cc b/src/ast/function.cc
index db0cffd..8f8ab73 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -49,19 +49,15 @@
Function::~Function() = default;
std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
- for (auto* deco : decorations_) {
- if (auto* workgroup = deco->As<WorkgroupDecoration>()) {
- return workgroup->values();
- }
+ if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
+ return workgroup->values();
}
return {1, 1, 1};
}
PipelineStage Function::pipeline_stage() const {
- for (auto* deco : decorations_) {
- if (auto* stage = deco->As<StageDecoration>()) {
- return stage->value();
- }
+ if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
+ return stage->value();
}
return PipelineStage::kNone;
}
diff --git a/src/ast/function.h b/src/ast/function.h
index 1352a71..eb856c3 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -63,18 +63,6 @@
/// @returns the decorations attached to this function
const DecorationList& decorations() const { return decorations_; }
- /// @returns the decoration with the type `T` or nullptr if this function does
- /// not contain a decoration with the given type
- template <typename T>
- const T* find_decoration() const {
- for (auto* deco : decorations()) {
- if (auto* d = deco->As<T>()) {
- return d;
- }
- }
- return nullptr;
- }
-
/// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
/// return if no workgroup size was set.
std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
diff --git a/src/ast/struct.cc b/src/ast/struct.cc
index 7d857b7..e38d350 100644
--- a/src/ast/struct.cc
+++ b/src/ast/struct.cc
@@ -50,12 +50,7 @@
}
bool Struct::IsBlockDecorated() const {
- for (auto* deco : decorations_) {
- if (deco->Is<StructBlockDecoration>()) {
- return true;
- }
- }
- return false;
+ return HasDecoration<StructBlockDecoration>(decorations_);
}
Struct* Struct::Clone(CloneContext* ctx) const {
diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc
index c576ee7..70b803f 100644
--- a/src/ast/struct_member.cc
+++ b/src/ast/struct_member.cc
@@ -41,19 +41,13 @@
StructMember::~StructMember() = default;
bool StructMember::has_offset_decoration() const {
- for (auto* deco : decorations_) {
- if (deco->Is<StructMemberOffsetDecoration>()) {
- return true;
- }
- }
- return false;
+ return HasDecoration<StructMemberOffsetDecoration>(decorations_);
}
uint32_t StructMember::offset() const {
- for (auto* deco : decorations_) {
- if (auto* offset = deco->As<StructMemberOffsetDecoration>()) {
- return offset->offset();
- }
+ if (auto* offset =
+ GetDecoration<StructMemberOffsetDecoration>(decorations_)) {
+ return offset->offset();
}
return 0;
}
diff --git a/src/ast/variable.cc b/src/ast/variable.cc
index 3c7d5a0..24308b5 100644
--- a/src/ast/variable.cc
+++ b/src/ast/variable.cc
@@ -59,49 +59,11 @@
return BindingPoint{group, binding};
}
-bool Variable::HasLocationDecoration() const {
- for (auto* deco : decorations_) {
- if (deco->Is<LocationDecoration>()) {
- return true;
- }
- }
- return false;
-}
-
-bool Variable::HasBuiltinDecoration() const {
- for (auto* deco : decorations_) {
- if (deco->Is<BuiltinDecoration>()) {
- return true;
- }
- }
- return false;
-}
-
-bool Variable::HasConstantIdDecoration() const {
- for (auto* deco : decorations_) {
- if (deco->Is<ConstantIdDecoration>()) {
- return true;
- }
- }
- return false;
-}
-
-LocationDecoration* Variable::GetLocationDecoration() const {
- for (auto* deco : decorations_) {
- if (deco->Is<LocationDecoration>()) {
- return deco->As<LocationDecoration>();
- }
- }
- return nullptr;
-}
-
uint32_t Variable::constant_id() const {
- TINT_ASSERT(HasConstantIdDecoration());
- for (auto* deco : decorations_) {
- if (auto* cid = deco->As<ConstantIdDecoration>()) {
- return cid->value();
- }
+ if (auto* cid = GetDecoration<ConstantIdDecoration>(decorations_)) {
+ return cid->value();
}
+ TINT_ASSERT(false);
return 0;
}
diff --git a/src/ast/variable.h b/src/ast/variable.h
index b8f50e2..a512d01 100644
--- a/src/ast/variable.h
+++ b/src/ast/variable.h
@@ -134,18 +134,8 @@
/// @returns the binding point information for the variable
BindingPoint binding_point() const;
- /// @returns true if the decorations include a LocationDecoration
- bool HasLocationDecoration() const;
- /// @returns true if the decorations include a BuiltinDecoration
- bool HasBuiltinDecoration() const;
- /// @returns true if the decorations include a ConstantIdDecoration
- bool HasConstantIdDecoration() const;
-
- /// @returns pointer to LocationDecoration in decorations, otherwise NULL.
- LocationDecoration* GetLocationDecoration() const;
-
- /// @returns the constant_id value for the variable. Assumes that
- /// HasConstantIdDecoration() has been called first.
+ /// @returns the constant_id value for the variable. Assumes that this
+ /// variable has a constant ID decoration.
uint32_t constant_id() const;
/// Clones this node and all transitive child nodes using the `CloneContext`
diff --git a/src/ast/variable_test.cc b/src/ast/variable_test.cc
index 258f988..929e922 100644
--- a/src/ast/variable_test.cc
+++ b/src/ast/variable_test.cc
@@ -98,11 +98,12 @@
create<ConstantIdDecoration>(1200),
});
- EXPECT_TRUE(var->HasLocationDecoration());
- EXPECT_TRUE(var->HasBuiltinDecoration());
- EXPECT_TRUE(var->HasConstantIdDecoration());
+ auto& decorations = var->decorations();
+ EXPECT_TRUE(ast::HasDecoration<ast::LocationDecoration>(decorations));
+ EXPECT_TRUE(ast::HasDecoration<ast::BuiltinDecoration>(decorations));
+ EXPECT_TRUE(ast::HasDecoration<ast::ConstantIdDecoration>(decorations));
- auto* location = var->GetLocationDecoration();
+ auto* location = ast::GetDecoration<ast::LocationDecoration>(decorations);
ASSERT_NE(nullptr, location);
EXPECT_EQ(1u, location->value());
}
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index 8e8770c..8ce0901 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -18,6 +18,7 @@
#include "src/ast/bool_literal.h"
#include "src/ast/float_literal.h"
+#include "src/ast/constant_id_decoration.h"
#include "src/ast/module.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
@@ -203,7 +204,7 @@
auto* decl = var->Declaration();
auto name = program_->Symbols().NameFor(decl->symbol());
- if (decl->HasBuiltinDecoration()) {
+ if (ast::HasDecoration<ast::BuiltinDecoration>(decl->decorations())) {
continue;
}
@@ -220,7 +221,8 @@
stage_variable.component_type = ComponentType::kSInt;
}
- auto* location_decoration = decl->GetLocationDecoration();
+ auto* location_decoration =
+ ast::GetDecoration<ast::LocationDecoration>(decl->decorations());
if (location_decoration) {
stage_variable.has_location_decoration = true;
stage_variable.location_decoration = location_decoration->value();
@@ -257,7 +259,7 @@
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
std::map<uint32_t, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) {
- if (!var->HasConstantIdDecoration()) {
+ if (!ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
continue;
}
diff --git a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
index 3a72f04..f63a0e0 100644
--- a/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_global_constant_decl_test.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "src/ast/constant_id_decoration.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
namespace tint {
@@ -43,7 +44,8 @@
ASSERT_NE(e->constructor(), nullptr);
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
- EXPECT_FALSE(e.value->HasConstantIdDecoration());
+ EXPECT_FALSE(
+ ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
}
TEST_F(ParserImplTest, GlobalConstantDecl_MissingEqual) {
@@ -123,7 +125,8 @@
ASSERT_NE(e->constructor(), nullptr);
EXPECT_TRUE(e->constructor()->Is<ast::ConstructorExpression>());
- EXPECT_TRUE(e.value->HasConstantIdDecoration());
+ EXPECT_TRUE(
+ ast::HasDecoration<ast::ConstantIdDecoration>(e.value->decorations()));
EXPECT_EQ(e.value->constant_id(), 7u);
}
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index da891d5..c9110a4 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -310,7 +310,8 @@
func->source());
return false;
}
- } else if (!func->find_decoration<ast::InternalDecoration>()) {
+ } else if (!ast::HasDecoration<ast::InternalDecoration>(
+ func->decorations())) {
TINT_ICE(diagnostics_)
<< "Function " << builder_->Symbols().NameFor(func->symbol())
<< " has no body and does not have the [[internal]] decoration";
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 4f82f6c..ed508a7 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1310,8 +1310,9 @@
}
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
- if (var->Declaration()->HasLocationDecoration() ||
- var->Declaration()->HasBuiltinDecoration()) {
+ auto& decorations = var->Declaration()->decorations();
+ if (ast::HasDecoration<ast::LocationDecoration>(decorations) ||
+ ast::HasDecoration<ast::BuiltinDecoration>(decorations)) {
return var->StorageClass() == ast::StorageClass::kInput ||
var->StorageClass() == ast::StorageClass::kOutput;
}
@@ -1463,7 +1464,7 @@
auto* func_sem = builder_.Sem().Get(func);
- if (func->find_decoration<ast::InternalDecoration>()) {
+ if (ast::HasDecoration<ast::InternalDecoration>(func->decorations())) {
// An internal function. Do not emit.
return true;
}
@@ -2825,7 +2826,7 @@
auto* type = builder_.Sem().Get(var)->Type();
- if (var->HasConstantIdDecoration()) {
+ if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
auto const_id = var->constant_id();
out << "#ifndef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index d7e47e1..98a2a85 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1562,12 +1562,15 @@
}
bool GeneratorImpl::global_is_in_struct(const semantic::Variable* var) const {
+ auto& decorations = var->Declaration()->decorations();
bool in_or_out_struct_has_location =
- var != nullptr && var->Declaration()->HasLocationDecoration() &&
+ var != nullptr &&
+ ast::HasDecoration<ast::LocationDecoration>(decorations) &&
(var->StorageClass() == ast::StorageClass::kInput ||
var->StorageClass() == ast::StorageClass::kOutput);
bool in_struct_has_builtin =
- var != nullptr && var->Declaration()->HasBuiltinDecoration() &&
+ var != nullptr &&
+ ast::HasDecoration<ast::BuiltinDecoration>(decorations) &&
var->StorageClass() == ast::StorageClass::kOutput;
return in_or_out_struct_has_location || in_struct_has_builtin;
}
@@ -2249,7 +2252,7 @@
out_ << " " << program_->Symbols().NameFor(var->symbol());
}
- if (var->HasConstantIdDecoration()) {
+ if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
out_ << " [[function_constant(" << var->constant_id() << ")]]";
} else if (var->constructor() != nullptr) {
out_ << " = ";
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index af2ec2f..005c606 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -750,7 +750,7 @@
// one
// 2- If we don't have a constructor and we're an Output or Private variable
// then WGSL requires an initializer.
- if (var->HasConstantIdDecoration()) {
+ if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
if (type_no_ac->Is<type::F32>()) {
ast::FloatLiteral l(Source{}, type_no_ac, 0.0f);
init_id = GenerateLiteralIfNeeded(var, &l);
@@ -1490,7 +1490,8 @@
ast::Literal* lit) {
ScalarConstant constant;
- if (var && var->HasConstantIdDecoration()) {
+ if (var &&
+ ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
constant.is_spec_op = true;
constant.constant_id = var->constant_id();
}