tint/resolver: Clean up attribute resolving
Attributes resolving was done ad-hoc throughout the resolver, with the
validator ensuring that attributes were only applied to the correct nodes.
The ad-hoc nature meant that attributes were inconsistently marked and
resolved, and the attribute arguments were not always validated
(especially when used internally).
This change inlines the attribute processing into the appropriate places
in the resolver, and uses a standardized error message for attributes
that cannot be applied.
Change-Id: Ic084820949bbf8276fb2d33c103fa29b77824a69
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/129620
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index 098ab17..b9160ce 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -131,6 +131,43 @@
return {};
}
+static std::string name(AttributeKind kind) {
+ switch (kind) {
+ case AttributeKind::kAlign:
+ return "@align";
+ case AttributeKind::kBinding:
+ return "@binding";
+ case AttributeKind::kBuiltin:
+ return "@builtin";
+ case AttributeKind::kDiagnostic:
+ return "@diagnostic";
+ case AttributeKind::kGroup:
+ return "@group";
+ case AttributeKind::kId:
+ return "@id";
+ case AttributeKind::kInterpolate:
+ return "@interpolate";
+ case AttributeKind::kInvariant:
+ return "@invariant";
+ case AttributeKind::kLocation:
+ return "@location";
+ case AttributeKind::kOffset:
+ return "@offset";
+ case AttributeKind::kMustUse:
+ return "@must_use";
+ case AttributeKind::kSize:
+ return "@size";
+ case AttributeKind::kStage:
+ return "@stage";
+ case AttributeKind::kStride:
+ return "@stride";
+ case AttributeKind::kWorkgroup:
+ return "@workgroup_size";
+ case AttributeKind::kBindingAndGroup:
+ return "@binding";
+ }
+ return "<unknown>";
+}
namespace FunctionInputAndOutputTests {
using FunctionParameterAttributeTest = TestWithParams;
TEST_P(FunctionParameterAttributeTest, IsValid) {
@@ -144,11 +181,16 @@
if (params.should_pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else if (params.kind == AttributeKind::kLocation || params.kind == AttributeKind::kBuiltin ||
+ params.kind == AttributeKind::kInvariant ||
+ params.kind == AttributeKind::kInterpolate) {
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "error: " + name(params.kind) +
+ " is not valid for non-entry point function parameters");
} else {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "error: attribute is not valid for non-entry point function "
- "parameters");
+ "error: " + name(params.kind) + " is not valid for function parameters");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -184,9 +226,9 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- "error: attribute is not valid for non-entry point function "
- "return types");
+ EXPECT_EQ(r()->error(), "error: " + name(params.kind) +
+ " is not valid for non-entry point function "
+ "return types");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -234,10 +276,11 @@
} else if (params.kind == AttributeKind::kInterpolate ||
params.kind == AttributeKind::kLocation ||
params.kind == AttributeKind::kInvariant) {
- EXPECT_EQ(r()->error(),
- "12:34 error: attribute is not valid for compute shader inputs");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for compute shader inputs");
} else {
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for function parameters");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for function parameters");
}
}
}
@@ -277,7 +320,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for function parameters");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for function parameters");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -331,7 +375,8 @@
"12:34 error: invariant attribute must only be applied to a "
"position builtin");
} else {
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for function parameters");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for function parameters");
}
}
}
@@ -378,12 +423,12 @@
} else if (params.kind == AttributeKind::kInterpolate ||
params.kind == AttributeKind::kLocation ||
params.kind == AttributeKind::kInvariant) {
- EXPECT_EQ(r()->error(),
- "12:34 error: attribute is not valid for compute shader output");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for compute shader output");
} else {
- EXPECT_EQ(r()->error(),
- "12:34 error: attribute is not valid for entry point return "
- "types");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for entry point return "
+ "types");
}
}
}
@@ -434,8 +479,8 @@
R"(34:56 error: duplicate location attribute
12:34 note: first attribute declared here)");
} else {
- EXPECT_EQ(r()->error(),
- R"(12:34 error: attribute is not valid for entry point return types)");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for entry point return types");
}
}
}
@@ -484,8 +529,8 @@
R"(34:56 error: multiple entry point IO attributes
12:34 note: previously consumed @location)");
} else {
- EXPECT_EQ(r()->error(),
- R"(12:34 error: attribute is not valid for entry point return types)");
+ EXPECT_EQ(r()->error(), "12:34 error: " + name(params.kind) +
+ " is not valid for entry point return types");
}
}
}
@@ -591,7 +636,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for struct declarations");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for struct declarations");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -628,7 +674,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for structure members");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for struct members");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -871,7 +918,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for array types");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for array types");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -898,7 +946,6 @@
auto& params = GetParam();
auto attrs = createAttributes(Source{{12, 34}}, *this, params.kind);
- auto* attr = attrs[0];
if (IsBindingAttribute(params.kind)) {
GlobalVar("a", ty.sampler(type::SamplerKind::kSampler), attrs);
} else {
@@ -910,8 +957,8 @@
} else {
EXPECT_FALSE(r()->Resolve());
if (!IsBindingAttribute(params.kind)) {
- EXPECT_EQ(r()->error(), "12:34 error: attribute '" + attr->Name() +
- "' is not valid for module-scope 'var'");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for module-scope 'var'");
}
}
}
@@ -944,13 +991,22 @@
12:34 note: first attribute declared here)");
}
-TEST_F(VariableAttributeTest, LocalVariable) {
+TEST_F(VariableAttributeTest, LocalVar) {
auto* v = Var("a", ty.f32(), utils::Vector{Binding(Source{{12, 34}}, 2_a)});
WrapInFunction(v);
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attributes are not valid on local variables");
+ EXPECT_EQ(r()->error(), "12:34 error: @binding is not valid for function-scope 'var'");
+}
+
+TEST_F(VariableAttributeTest, LocalLet) {
+ auto* v = Let("a", utils::Vector{Binding(Source{{12, 34}}, 2_a)}, Expr(1_a));
+
+ WrapInFunction(v);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: @binding is not valid for 'let' declaration");
}
using ConstantAttributeTest = TestWithParams;
@@ -965,7 +1021,7 @@
} else {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: attribute is not valid for module-scope 'const' declaration");
+ "12:34 error: " + name(params.kind) + " is not valid for 'const' declaration");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -987,17 +1043,14 @@
TestParams{AttributeKind::kWorkgroup, false},
TestParams{AttributeKind::kBindingAndGroup, false}));
-TEST_F(ConstantAttributeTest, DuplicateAttribute) {
+TEST_F(ConstantAttributeTest, InvalidAttribute) {
GlobalConst("a", ty.f32(), Expr(1.23_f),
utils::Vector{
Id(Source{{12, 34}}, 0_a),
- Id(Source{{56, 78}}, 1_a),
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(),
- R"(56:78 error: duplicate id attribute
-12:34 note: first attribute declared here)");
+ EXPECT_EQ(r()->error(), "12:34 error: @id is not valid for 'const' declaration");
}
using OverrideAttributeTest = TestWithParams;
@@ -1010,7 +1063,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for 'override' declaration");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for 'override' declaration");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1056,7 +1110,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for switch statements");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for switch statements");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1089,7 +1144,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for switch body");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for switch body");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1122,7 +1178,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for if statements");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for if statements");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1155,7 +1212,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for for statements");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for for statements");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1188,7 +1246,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for loop statements");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for loop statements");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1221,7 +1280,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: attribute is not valid for while statements");
+ EXPECT_EQ(r()->error(),
+ "12:34 error: " + name(params.kind) + " is not valid for while statements");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
@@ -1251,7 +1311,8 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "error: attribute is not valid for block statements");
+ EXPECT_EQ(r()->error(),
+ "error: " + name(GetParam().kind) + " is not valid for block statements");
}
}
};
diff --git a/src/tint/resolver/entry_point_validation_test.cc b/src/tint/resolver/entry_point_validation_test.cc
index dac6cd2..eb2d779 100644
--- a/src/tint/resolver/entry_point_validation_test.cc
+++ b/src/tint/resolver/entry_point_validation_test.cc
@@ -1084,7 +1084,7 @@
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: attribute is not valid for compute shader output)");
+ EXPECT_EQ(r()->error(), R"(12:34 error: @location is not valid for compute shader output)");
}
TEST_F(LocationAttributeTests, ComputeShaderLocation_Output) {
@@ -1099,7 +1099,7 @@
});
EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), R"(12:34 error: attribute is not valid for compute shader inputs)");
+ EXPECT_EQ(r()->error(), R"(12:34 error: @location is not valid for compute shader inputs)");
}
TEST_F(LocationAttributeTests, ComputeShaderLocationStructMember_Output) {
@@ -1119,7 +1119,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: attribute is not valid for compute shader output\n"
+ "12:34 error: @location is not valid for compute shader output\n"
"56:78 note: while analyzing entry point 'main'");
}
@@ -1138,7 +1138,7 @@
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
- "12:34 error: attribute is not valid for compute shader inputs\n"
+ "12:34 error: @location is not valid for compute shader inputs\n"
"56:78 note: while analyzing entry point 'main'");
}
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 99e68d0..638cac0 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -247,6 +247,20 @@
}
}
+ for (auto* attribute : v->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, //
+ [&](const ast::InternalAttribute* attr) -> bool { return InternalAttribute(attr); },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "'let' declaration");
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
+ }
+
if (!v->initializer) {
AddError("'let' declaration must have an initializer", v->source);
return nullptr;
@@ -340,37 +354,51 @@
/* constant_value */ nullptr, std::nullopt, std::nullopt);
sem->SetInitializer(rhs);
- if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
- ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@id"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+ for (auto* attribute : v->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, //
+ [&](const ast::IdAttribute* attr) {
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@id"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- auto* materialized = Materialize(ValueExpression(id_attr->expr));
- if (!materialized) {
+ auto* materialized = Materialize(ValueExpression(attr->expr));
+ if (!materialized) {
+ return false;
+ }
+ if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
+ AddError("@id must be an i32 or u32 value", attr->source);
+ return false;
+ }
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->ValueAs<AInt>();
+ if (value < 0) {
+ AddError("@id value must be non-negative", attr->source);
+ return false;
+ }
+ if (value > std::numeric_limits<decltype(OverrideId::value)>::max()) {
+ AddError(
+ "@id value must be between 0 and " +
+ std::to_string(std::numeric_limits<decltype(OverrideId::value)>::max()),
+ attr->source);
+ return false;
+ }
+
+ auto o = OverrideId{static_cast<decltype(OverrideId::value)>(value)};
+ sem->SetOverrideId(o);
+
+ // Track the constant IDs that are specified in the shader.
+ override_ids_.Add(o, sem);
+ return true;
+ },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "'override' declaration");
+ return false;
+ });
+ if (!ok) {
return nullptr;
}
- if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
- AddError("@id must be an i32 or u32 value", id_attr->source);
- return nullptr;
- }
-
- auto const_value = materialized->ConstantValue();
- auto value = const_value->ValueAs<AInt>();
- if (value < 0) {
- AddError("@id value must be non-negative", id_attr->source);
- return nullptr;
- }
- if (value > std::numeric_limits<decltype(OverrideId::value)>::max()) {
- AddError("@id value must be between 0 and " +
- std::to_string(std::numeric_limits<decltype(OverrideId::value)>::max()),
- id_attr->source);
- return nullptr;
- }
-
- auto o = OverrideId{static_cast<decltype(OverrideId::value)>(value)};
- sem->SetOverrideId(o);
-
- // Track the constant IDs that are specified in the shader.
- override_ids_.Add(o, sem);
}
builder_->Sem().Add(v, sem);
@@ -393,6 +421,18 @@
return nullptr;
}
+ for (auto* attribute : c->attributes) {
+ Mark(attribute);
+ bool ok = Switch(attribute, //
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "'const' declaration");
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
+ }
+
const sem::ValueExpression* rhs = nullptr;
{
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "const initializer"};
@@ -529,72 +569,98 @@
sem::Variable* sem = nullptr;
if (is_global) {
+ bool has_io_address_space = address_space == builtin::AddressSpace::kIn ||
+ address_space == builtin::AddressSpace::kOut;
+
+ std::optional<uint32_t> group, binding, location;
+ for (auto* attribute : var->attributes) {
+ Mark(attribute);
+ enum Status { kSuccess, kErrored, kInvalid };
+ auto res = Switch(
+ attribute, //
+ [&](const ast::BindingAttribute* attr) {
+ auto value = BindingAttribute(attr);
+ if (!value) {
+ return kErrored;
+ }
+ binding = value.Get();
+ return kSuccess;
+ },
+ [&](const ast::GroupAttribute* attr) {
+ auto value = GroupAttribute(attr);
+ if (!value) {
+ return kErrored;
+ }
+ group = value.Get();
+ return kSuccess;
+ },
+ [&](const ast::LocationAttribute* attr) {
+ if (!has_io_address_space) {
+ return kInvalid;
+ }
+ auto value = LocationAttribute(attr);
+ if (!value) {
+ return kErrored;
+ }
+ location = value.Get();
+ return kSuccess;
+ },
+ [&](const ast::BuiltinAttribute* attr) {
+ if (!has_io_address_space) {
+ return kInvalid;
+ }
+ return BuiltinAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::InterpolateAttribute* attr) {
+ if (!has_io_address_space) {
+ return kInvalid;
+ }
+ return InterpolateAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::InvariantAttribute* attr) {
+ if (!has_io_address_space) {
+ return kInvalid;
+ }
+ return InvariantAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::InternalAttribute* attr) {
+ return InternalAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](Default) { return kInvalid; });
+
+ switch (res) {
+ case kSuccess:
+ break;
+ case kErrored:
+ return nullptr;
+ case kInvalid:
+ ErrorInvalidAttribute(attribute, "module-scope 'var'");
+ return nullptr;
+ }
+ }
+
std::optional<sem::BindingPoint> binding_point;
- if (var->HasBindingPoint()) {
- uint32_t binding = 0;
- {
- ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
-
- auto* attr = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
- auto* materialized = Materialize(ValueExpression(attr->expr));
- if (!materialized) {
- return nullptr;
- }
- if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
- AddError("@binding must be an i32 or u32 value", attr->source);
- return nullptr;
- }
-
- auto const_value = materialized->ConstantValue();
- auto value = const_value->ValueAs<AInt>();
- if (value < 0) {
- AddError("@binding value must be non-negative", attr->source);
- return nullptr;
- }
- binding = u32(value);
- }
-
- uint32_t group = 0;
- {
- ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
-
- auto* attr = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
- auto* materialized = Materialize(ValueExpression(attr->expr));
- if (!materialized) {
- return nullptr;
- }
- if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
- AddError("@group must be an i32 or u32 value", attr->source);
- return nullptr;
- }
-
- auto const_value = materialized->ConstantValue();
- auto value = const_value->ValueAs<AInt>();
- if (value < 0) {
- AddError("@group value must be non-negative", attr->source);
- return nullptr;
- }
- group = u32(value);
- }
- binding_point = {group, binding};
+ if (group && binding) {
+ binding_point = sem::BindingPoint{group.value(), binding.value()};
}
-
- std::optional<uint32_t> location;
- if (auto* attr = ast::GetAttribute<ast::LocationAttribute>(var->attributes)) {
- auto value = LocationAttribute(attr);
- if (!value) {
- return nullptr;
- }
- location = value.Get();
- }
-
sem = builder_->create<sem::GlobalVariable>(
var, var_ty, sem::EvaluationStage::kRuntime, address_space, access,
/* constant_value */ nullptr, binding_point, location);
} else {
+ for (auto* attribute : var->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute,
+ [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "function-scope 'var'");
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
+ }
sem = builder_->create<sem::LocalVariable>(var, var_ty, sem::EvaluationStage::kRuntime,
address_space, access, current_statement_,
/* constant_value */ nullptr);
@@ -605,18 +671,93 @@
return sem;
}
-sem::Parameter* Resolver::Parameter(const ast::Parameter* param, uint32_t index) {
+sem::Parameter* Resolver::Parameter(const ast::Parameter* param,
+ const ast::Function* func,
+ uint32_t index) {
Mark(param->name);
auto add_note = [&] {
AddNote("while instantiating parameter " + param->name->symbol.Name(), param->source);
};
- for (auto* attr : param->attributes) {
- if (!Attribute(attr)) {
- return nullptr;
+ std::optional<uint32_t> location, group, binding;
+
+ if (func->IsEntryPoint()) {
+ for (auto* attribute : param->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, //
+ [&](const ast::LocationAttribute* attr) {
+ auto value = LocationAttribute(attr);
+ if (!value) {
+ return false;
+ }
+ location = value.Get();
+ return true;
+ },
+ [&](const ast::BuiltinAttribute* attr) -> bool { return BuiltinAttribute(attr); },
+ [&](const ast::InvariantAttribute* attr) -> bool {
+ return InvariantAttribute(attr);
+ },
+ [&](const ast::InterpolateAttribute* attr) -> bool {
+ return InterpolateAttribute(attr);
+ },
+ [&](const ast::InternalAttribute* attr) -> bool { return InternalAttribute(attr); },
+ [&](const ast::GroupAttribute* attr) -> bool {
+ if (validator_.IsValidationEnabled(
+ param->attributes, ast::DisabledValidation::kEntryPointParameter)) {
+ ErrorInvalidAttribute(attribute, "function parameters");
+ return false;
+ }
+ auto value = GroupAttribute(attr);
+ if (!value) {
+ return false;
+ }
+ group = value.Get();
+ return true;
+ },
+ [&](const ast::BindingAttribute* attr) -> bool {
+ if (validator_.IsValidationEnabled(
+ param->attributes, ast::DisabledValidation::kEntryPointParameter)) {
+ ErrorInvalidAttribute(attribute, "function parameters");
+ return false;
+ }
+ auto value = BindingAttribute(attr);
+ if (!value) {
+ return false;
+ }
+ binding = value.Get();
+ return true;
+ },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "function parameters");
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
+ }
+ } else {
+ for (auto* attribute : param->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, //
+ [&](const ast::InternalAttribute* attr) -> bool { return InternalAttribute(attr); },
+ [&](Default) {
+ if (attribute->IsAnyOf<ast::LocationAttribute, ast::BuiltinAttribute,
+ ast::InvariantAttribute, ast::InterpolateAttribute>()) {
+ ErrorInvalidAttribute(attribute, "non-entry point function parameters");
+ } else {
+ ErrorInvalidAttribute(attribute, "function parameters");
+ }
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
}
}
+
if (!validator_.NoDuplicateAttributes(param->attributes)) {
return nullptr;
}
@@ -642,72 +783,22 @@
}
std::optional<sem::BindingPoint> binding_point;
- if (param->HasBindingPoint()) {
- binding_point = sem::BindingPoint{};
- {
- ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding value"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
-
- auto* attr = ast::GetAttribute<ast::BindingAttribute>(param->attributes);
- auto* materialized = Materialize(ValueExpression(attr->expr));
- if (!materialized) {
- return nullptr;
- }
- binding_point->binding = materialized->ConstantValue()->ValueAs<u32>();
- }
- {
- ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group value"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
-
- auto* attr = ast::GetAttribute<ast::GroupAttribute>(param->attributes);
- auto* materialized = Materialize(ValueExpression(attr->expr));
- if (!materialized) {
- return nullptr;
- }
- binding_point->group = materialized->ConstantValue()->ValueAs<u32>();
- }
- }
-
- std::optional<uint32_t> location;
- if (auto* attr = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) {
- auto value = LocationAttribute(attr);
- if (!value) {
- return nullptr;
- }
- location = value.Get();
+ if (group && binding) {
+ binding_point = sem::BindingPoint{group.value(), binding.value()};
}
auto* sem = builder_->create<sem::Parameter>(
param, index, ty, builtin::AddressSpace::kUndefined, builtin::Access::kUndefined,
sem::ParameterUsage::kNone, binding_point, location);
builder_->Sem().Add(param, sem);
+
+ if (!validator_.Parameter(sem)) {
+ return nullptr;
+ }
+
return sem;
}
-utils::Result<uint32_t> Resolver::LocationAttribute(const ast::LocationAttribute* attr) {
- ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@location value"};
- TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
-
- auto* materialized = Materialize(ValueExpression(attr->expr));
- if (!materialized) {
- return utils::Failure;
- }
-
- if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
- AddError("@location must be an i32 or u32 value", attr->source);
- return utils::Failure;
- }
-
- auto const_value = materialized->ConstantValue();
- auto value = const_value->ValueAs<AInt>();
- if (value < 0) {
- AddError("@location value must be non-negative", attr->source);
- return utils::Failure;
- }
-
- return static_cast<uint32_t>(value);
-}
-
builtin::Access Resolver::DefaultAccessForAddressSpace(builtin::AddressSpace address_space) {
// https://gpuweb.github.io/gpuweb/wgsl/#storage-class
switch (address_space) {
@@ -796,12 +887,6 @@
return nullptr;
}
- for (auto* attr : v->attributes) {
- if (!Attribute(attr)) {
- return nullptr;
- }
- }
-
if (!validator_.NoDuplicateAttributes(v->attributes)) {
return nullptr;
}
@@ -860,8 +945,28 @@
validator_.DiagnosticFilters().Push();
TINT_DEFER(validator_.DiagnosticFilters().Pop());
- for (auto* attr : decl->attributes) {
- if (!Attribute(attr)) {
+
+ for (auto* attribute : decl->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute,
+ [&](const ast::DiagnosticAttribute* attr) { return DiagnosticAttribute(attr); },
+ [&](const ast::StageAttribute* attr) { return StageAttribute(attr); },
+ [&](const ast::MustUseAttribute* attr) { return MustUseAttribute(attr); },
+ [&](const ast::WorkgroupAttribute* attr) {
+ auto value = WorkgroupAttribute(attr);
+ if (!value) {
+ return false;
+ }
+ func->SetWorkgroupSize(value.Get());
+ return true;
+ },
+ [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "functions");
+ return false;
+ });
+ if (!ok) {
return nullptr;
}
}
@@ -884,15 +989,11 @@
}
}
- auto* p = Parameter(param, parameter_index++);
+ auto* p = Parameter(param, decl, parameter_index++);
if (!p) {
return nullptr;
}
- if (!validator_.Parameter(decl, p)) {
- return nullptr;
- }
-
func->AddParameter(p);
auto* p_ty = const_cast<type::Type*>(p->Type());
@@ -925,18 +1026,73 @@
}
func->SetReturnType(return_type);
- // Determine if the return type has a location
- for (auto* attr : decl->return_type_attributes) {
- if (!Attribute(attr)) {
- return nullptr;
- }
+ if (decl->IsEntryPoint()) {
+ // Determine if the return type has a location
+ bool permissive = validator_.IsValidationDisabled(
+ decl->attributes, ast::DisabledValidation::kEntryPointParameter) ||
+ validator_.IsValidationDisabled(
+ decl->attributes, ast::DisabledValidation::kFunctionParameter);
+ for (auto* attribute : decl->return_type_attributes) {
+ Mark(attribute);
+ enum Status { kSuccess, kErrored, kInvalid };
+ auto res = Switch(
+ attribute, //
+ [&](const ast::LocationAttribute* attr) {
+ auto value = LocationAttribute(attr);
+ if (!value) {
+ return kErrored;
+ }
+ func->SetReturnLocation(value.Get());
+ return kSuccess;
+ },
+ [&](const ast::BuiltinAttribute* attr) {
+ return BuiltinAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::InternalAttribute* attr) {
+ return InternalAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::InterpolateAttribute* attr) {
+ return InterpolateAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::InvariantAttribute* attr) {
+ return InvariantAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::BindingAttribute* attr) {
+ if (!permissive) {
+ return kInvalid;
+ }
+ return BindingAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](const ast::GroupAttribute* attr) {
+ if (!permissive) {
+ return kInvalid;
+ }
+ return GroupAttribute(attr) ? kSuccess : kErrored;
+ },
+ [&](Default) { return kInvalid; });
- if (auto* loc_attr = attr->As<ast::LocationAttribute>()) {
- auto value = LocationAttribute(loc_attr);
- if (!value) {
+ switch (res) {
+ case kSuccess:
+ break;
+ case kErrored:
+ return nullptr;
+ case kInvalid:
+ ErrorInvalidAttribute(attribute, "entry point return types");
+ return nullptr;
+ }
+ }
+ } else {
+ for (auto* attribute : decl->return_type_attributes) {
+ Mark(attribute);
+ bool ok = Switch(attribute, //
+ [&](Default) {
+ ErrorInvalidAttribute(attribute,
+ "non-entry point function return types");
+ return false;
+ });
+ if (!ok) {
return nullptr;
}
- func->SetReturnLocation(value.Get());
}
}
@@ -964,10 +1120,6 @@
ApplyDiagnosticSeverities(func);
- if (!WorkgroupSize(decl)) {
- return nullptr;
- }
-
if (decl->IsEntryPoint()) {
entry_points_.Push(func);
}
@@ -1016,94 +1168,6 @@
return func;
}
-bool Resolver::WorkgroupSize(const ast::Function* func) {
- // Set work-group size defaults.
- sem::WorkgroupSize ws;
- for (size_t i = 0; i < 3; i++) {
- ws[i] = 1;
- }
-
- auto* attr = ast::GetAttribute<ast::WorkgroupAttribute>(func->attributes);
- if (!attr) {
- return true;
- }
-
- auto values = attr->Values();
- utils::Vector<const sem::ValueExpression*, 3> args;
- utils::Vector<const type::Type*, 3> arg_tys;
-
- constexpr const char* kErrBadExpr =
- "workgroup_size argument must be a constant or override-expression of type "
- "abstract-integer, i32 or u32";
-
- for (size_t i = 0; i < 3; i++) {
- // Each argument to this attribute can either be a literal, an identifier for a
- // module-scope constants, a const-expression, or nullptr if not specified.
- auto* value = values[i];
- if (!value) {
- break;
- }
- const auto* expr = ValueExpression(value);
- if (!expr) {
- return false;
- }
- auto* ty = expr->Type();
- if (!ty->IsAnyOf<type::I32, type::U32, type::AbstractInt>()) {
- AddError(kErrBadExpr, value->source);
- return false;
- }
-
- if (expr->Stage() != sem::EvaluationStage::kConstant &&
- expr->Stage() != sem::EvaluationStage::kOverride) {
- AddError(kErrBadExpr, value->source);
- return false;
- }
-
- args.Push(expr);
- arg_tys.Push(ty);
- }
-
- auto* common_ty = type::Type::Common(arg_tys);
- if (!common_ty) {
- AddError("workgroup_size arguments must be of the same type, either i32 or u32",
- attr->source);
- return false;
- }
-
- // If all arguments are abstract-integers, then materialize to i32.
- if (common_ty->Is<type::AbstractInt>()) {
- common_ty = builder_->create<type::I32>();
- }
-
- for (size_t i = 0; i < args.Length(); i++) {
- auto* materialized = Materialize(args[i], common_ty);
- if (!materialized) {
- return false;
- }
- if (auto* value = materialized->ConstantValue()) {
- if (value->ValueAs<AInt>() < 1) {
- AddError("workgroup_size argument must be at least 1", values[i]->source);
- return false;
- }
- ws[i] = value->ValueAs<u32>();
- } else {
- ws[i] = std::nullopt;
- }
- }
-
- uint64_t total_size = static_cast<uint64_t>(ws[0].value_or(1));
- for (size_t i = 1; i < 3; i++) {
- total_size *= static_cast<uint64_t>(ws[i].value_or(1));
- if (total_size > 0xffffffff) {
- AddError("total workgroup grid size cannot exceed 0xffffffff", values[i]->source);
- return false;
- }
- }
-
- current_function_->SetWorkgroupSize(std::move(ws));
- return true;
-}
-
bool Resolver::Statements(utils::VectorRef<const ast::Statement*> stmts) {
sem::Behaviors behaviors{sem::Behavior::kNext};
@@ -3474,25 +3538,186 @@
return sem;
}
-bool Resolver::Attribute(const ast::Attribute* attr) {
- Mark(attr);
- return Switch(
- attr, //
- [&](const ast::BuiltinAttribute* b) { return BuiltinAttribute(b); },
- [&](const ast::DiagnosticAttribute* d) { return DiagnosticControl(d->control); },
- [&](const ast::InterpolateAttribute* i) { return InterpolateAttribute(i); },
- [&](const ast::InternalAttribute* i) { return InternalAttribute(i); },
- [&](Default) { return true; });
+utils::Result<uint32_t> Resolver::LocationAttribute(const ast::LocationAttribute* attr) {
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@location value"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
+ auto* materialized = Materialize(ValueExpression(attr->expr));
+ if (!materialized) {
+ return utils::Failure;
+ }
+
+ if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
+ AddError("@location must be an i32 or u32 value", attr->source);
+ return utils::Failure;
+ }
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->ValueAs<AInt>();
+ if (value < 0) {
+ AddError("@location value must be non-negative", attr->source);
+ return utils::Failure;
+ }
+
+ return static_cast<uint32_t>(value);
}
-bool Resolver::BuiltinAttribute(const ast::BuiltinAttribute* attr) {
+utils::Result<uint32_t> Resolver::BindingAttribute(const ast::BindingAttribute* attr) {
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@binding"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
+ auto* materialized = Materialize(ValueExpression(attr->expr));
+ if (!materialized) {
+ return utils::Failure;
+ }
+ if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
+ AddError("@binding must be an i32 or u32 value", attr->source);
+ return utils::Failure;
+ }
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->ValueAs<AInt>();
+ if (value < 0) {
+ AddError("@binding value must be non-negative", attr->source);
+ return utils::Failure;
+ }
+ return static_cast<uint32_t>(value);
+}
+
+utils::Result<uint32_t> Resolver::GroupAttribute(const ast::GroupAttribute* attr) {
+ ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@group"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
+ auto* materialized = Materialize(ValueExpression(attr->expr));
+ if (!materialized) {
+ return utils::Failure;
+ }
+ if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
+ AddError("@group must be an i32 or u32 value", attr->source);
+ return utils::Failure;
+ }
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->ValueAs<AInt>();
+ if (value < 0) {
+ AddError("@group value must be non-negative", attr->source);
+ return utils::Failure;
+ }
+ return static_cast<uint32_t>(value);
+}
+
+utils::Result<sem::WorkgroupSize> Resolver::WorkgroupAttribute(
+ const ast::WorkgroupAttribute* attr) {
+ // Set work-group size defaults.
+ sem::WorkgroupSize ws;
+ for (size_t i = 0; i < 3; i++) {
+ ws[i] = 1;
+ }
+
+ auto values = attr->Values();
+ utils::Vector<const sem::ValueExpression*, 3> args;
+ utils::Vector<const type::Type*, 3> arg_tys;
+
+ constexpr const char* kErrBadExpr =
+ "workgroup_size argument must be a constant or override-expression of type "
+ "abstract-integer, i32 or u32";
+
+ for (size_t i = 0; i < 3; i++) {
+ // Each argument to this attribute can either be a literal, an identifier for a
+ // module-scope constants, a const-expression, or nullptr if not specified.
+ auto* value = values[i];
+ if (!value) {
+ break;
+ }
+ const auto* expr = ValueExpression(value);
+ if (!expr) {
+ return utils::Failure;
+ }
+ auto* ty = expr->Type();
+ if (!ty->IsAnyOf<type::I32, type::U32, type::AbstractInt>()) {
+ AddError(kErrBadExpr, value->source);
+ return utils::Failure;
+ }
+
+ if (expr->Stage() != sem::EvaluationStage::kConstant &&
+ expr->Stage() != sem::EvaluationStage::kOverride) {
+ AddError(kErrBadExpr, value->source);
+ return utils::Failure;
+ }
+
+ args.Push(expr);
+ arg_tys.Push(ty);
+ }
+
+ auto* common_ty = type::Type::Common(arg_tys);
+ if (!common_ty) {
+ AddError("workgroup_size arguments must be of the same type, either i32 or u32",
+ attr->source);
+ return utils::Failure;
+ }
+
+ // If all arguments are abstract-integers, then materialize to i32.
+ if (common_ty->Is<type::AbstractInt>()) {
+ common_ty = builder_->create<type::I32>();
+ }
+
+ for (size_t i = 0; i < args.Length(); i++) {
+ auto* materialized = Materialize(args[i], common_ty);
+ if (!materialized) {
+ return utils::Failure;
+ }
+ if (auto* value = materialized->ConstantValue()) {
+ if (value->ValueAs<AInt>() < 1) {
+ AddError("workgroup_size argument must be at least 1", values[i]->source);
+ return utils::Failure;
+ }
+ ws[i] = value->ValueAs<u32>();
+ } else {
+ ws[i] = std::nullopt;
+ }
+ }
+
+ uint64_t total_size = static_cast<uint64_t>(ws[0].value_or(1));
+ for (size_t i = 1; i < 3; i++) {
+ total_size *= static_cast<uint64_t>(ws[i].value_or(1));
+ if (total_size > 0xffffffff) {
+ AddError("total workgroup grid size cannot exceed 0xffffffff", values[i]->source);
+ return utils::Failure;
+ }
+ }
+
+ return ws;
+}
+
+utils::Result<tint::builtin::BuiltinValue> Resolver::BuiltinAttribute(
+ const ast::BuiltinAttribute* attr) {
auto* builtin_expr = BuiltinValueExpression(attr->builtin);
if (!builtin_expr) {
- return false;
+ return utils::Failure;
}
// Apply the resolved tint::sem::BuiltinEnumExpression<tint::builtin::BuiltinValue> to the
// attribute.
builder_->Sem().Add(attr, builtin_expr);
+ return builtin_expr->Value();
+}
+
+bool Resolver::DiagnosticAttribute(const ast::DiagnosticAttribute* attr) {
+ return DiagnosticControl(attr->control);
+}
+
+bool Resolver::StageAttribute(const ast::StageAttribute*) {
+ return true;
+}
+
+bool Resolver::MustUseAttribute(const ast::MustUseAttribute*) {
+ return true;
+}
+
+bool Resolver::InvariantAttribute(const ast::InvariantAttribute*) {
+ return true;
+}
+
+bool Resolver::StrideAttribute(const ast::StrideAttribute*) {
return true;
}
@@ -3626,24 +3851,30 @@
return false;
}
- for (auto* attr : attributes) {
- Mark(attr);
- if (auto* sd = attr->As<ast::StrideAttribute>()) {
- // If the element type is not plain, then el_ty->Align() may be 0, in which case we
- // could get a DBZ in ArrayStrideAttribute(). In this case, validation will error
- // about the invalid array element type (which is tested later), so this is just a
- // seatbelt.
- if (IsPlain(el_ty)) {
- explicit_stride = sd->stride;
- if (!validator_.ArrayStrideAttribute(sd, el_ty->Size(), el_ty->Align())) {
- return false;
+ for (auto* attribute : attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, //
+ [&](const ast::StrideAttribute* attr) {
+ // If the element type is not plain, then el_ty->Align() may be 0, in which case we
+ // could get a DBZ in ArrayStrideAttribute(). In this case, validation will error
+ // about the invalid array element type (which is tested later), so this is just a
+ // seatbelt.
+ if (IsPlain(el_ty)) {
+ explicit_stride = attr->stride;
+ if (!validator_.ArrayStrideAttribute(attr, el_ty->Size(), el_ty->Align())) {
+ return false;
+ }
}
- }
- continue;
+ return true;
+ },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "array types");
+ return false;
+ });
+ if (!ok) {
+ return false;
}
-
- AddError("attribute is not valid for array types", attr->source);
- return false;
}
return true;
@@ -3727,8 +3958,18 @@
if (!validator_.NoDuplicateAttributes(str->attributes)) {
return nullptr;
}
- for (auto* attr : str->attributes) {
- Mark(attr);
+
+ for (auto* attribute : str->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "struct declarations");
+ return false;
+ });
+ if (!ok) {
+ return nullptr;
+ }
}
utils::Vector<const sem::StructMember*, 8> sem_members;
@@ -3781,88 +4022,87 @@
bool has_align_attr = false;
bool has_size_attr = false;
std::optional<uint32_t> location;
- for (auto* attr : member->attributes) {
- if (!Attribute(attr)) {
- return nullptr;
- }
+ for (auto* attribute : member->attributes) {
+ Mark(attribute);
bool ok = Switch(
- attr, //
- [&](const ast::StructMemberOffsetAttribute* o) {
- // Offset attributes are not part of the WGSL spec, but are emitted
- // by the SPIR-V reader.
+ attribute, //
+ [&](const ast::StructMemberOffsetAttribute* attr) {
+ // Offset attributes are not part of the WGSL spec, but are emitted by the
+ // SPIR-V reader.
+
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant,
"@offset value"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- auto* materialized = Materialize(ValueExpression(o->expr));
+ auto* materialized = Materialize(ValueExpression(attr->expr));
if (!materialized) {
return false;
}
auto const_value = materialized->ConstantValue();
if (!const_value) {
- AddError("@offset must be constant expression", o->expr->source);
+ AddError("@offset must be constant expression", attr->expr->source);
return false;
}
offset = const_value->ValueAs<uint64_t>();
if (offset < struct_size) {
- AddError("offsets must be in ascending order", o->source);
+ AddError("offsets must be in ascending order", attr->source);
return false;
}
has_offset_attr = true;
return true;
},
- [&](const ast::StructMemberAlignAttribute* a) {
+ [&](const ast::StructMemberAlignAttribute* attr) {
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@align"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- auto* materialized = Materialize(ValueExpression(a->expr));
+ auto* materialized = Materialize(ValueExpression(attr->expr));
if (!materialized) {
return false;
}
if (!materialized->Type()->IsAnyOf<type::I32, type::U32>()) {
- AddError("@align must be an i32 or u32 value", a->source);
+ AddError("@align must be an i32 or u32 value", attr->source);
return false;
}
auto const_value = materialized->ConstantValue();
if (!const_value) {
- AddError("@align must be constant expression", a->source);
+ AddError("@align must be constant expression", attr->source);
return false;
}
auto value = const_value->ValueAs<AInt>();
if (value <= 0 || !utils::IsPowerOfTwo(value)) {
AddError("@align value must be a positive, power-of-two integer",
- a->source);
+ attr->source);
return false;
}
align = u32(value);
has_align_attr = true;
return true;
},
- [&](const ast::StructMemberSizeAttribute* s) {
+ [&](const ast::StructMemberSizeAttribute* attr) {
ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "@size"};
TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
- auto* materialized = Materialize(ValueExpression(s->expr));
+ auto* materialized = Materialize(ValueExpression(attr->expr));
if (!materialized) {
return false;
}
if (!materialized->Type()->IsAnyOf<type::U32, type::I32>()) {
- AddError("@size must be an i32 or u32 value", s->source);
+ AddError("@size must be an i32 or u32 value", attr->source);
return false;
}
auto const_value = materialized->ConstantValue();
if (!const_value) {
- AddError("@size must be constant expression", s->expr->source);
+ AddError("@size must be constant expression", attr->expr->source);
return false;
}
{
auto value = const_value->ValueAs<AInt>();
if (value <= 0) {
- AddError("@size must be a positive integer", s->source);
+ AddError("@size must be a positive integer", attr->source);
return false;
}
}
@@ -3870,24 +4110,36 @@
if (value < size) {
AddError("@size must be at least as big as the type's size (" +
std::to_string(size) + ")",
- s->source);
+ attr->source);
return false;
}
size = u32(value);
has_size_attr = true;
return true;
},
- [&](const ast::LocationAttribute* loc_attr) {
- auto value = LocationAttribute(loc_attr);
+ [&](const ast::LocationAttribute* attr) {
+ auto value = LocationAttribute(attr);
if (!value) {
return false;
}
location = value.Get();
return true;
},
+ [&](const ast::BuiltinAttribute* attr) -> bool { return BuiltinAttribute(attr); },
+ [&](const ast::InterpolateAttribute* attr) { return InterpolateAttribute(attr); },
+ [&](const ast::InvariantAttribute* attr) { return InvariantAttribute(attr); },
+ [&](const ast::StrideAttribute* attr) {
+ if (validator_.IsValidationEnabled(
+ member->attributes, ast::DisabledValidation::kIgnoreStrideAttribute)) {
+ ErrorInvalidAttribute(attribute, "struct members");
+ return false;
+ }
+ return StrideAttribute(attr);
+ },
+ [&](const ast::InternalAttribute* attr) { return InternalAttribute(attr); },
[&](Default) {
- // The validator will check attributes can be applied to the struct member.
- return true;
+ ErrorInvalidAttribute(attribute, "struct members");
+ return false;
});
if (!ok) {
return nullptr;
@@ -4049,14 +4301,16 @@
}
// Handle switch body attributes.
- for (auto* attr : stmt->body_attributes) {
- Mark(attr);
- if (auto* dc = attr->As<ast::DiagnosticAttribute>()) {
- if (!DiagnosticControl(dc->control)) {
+ for (auto* attribute : stmt->body_attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute,
+ [&](const ast::DiagnosticAttribute* attr) { return DiagnosticAttribute(attr); },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, "switch body");
return false;
- }
- } else {
- AddError("attribute is not valid for switch body", attr->source);
+ });
+ if (!ok) {
return false;
}
}
@@ -4099,14 +4353,6 @@
return false;
}
- for (auto* attr : stmt->variable->attributes) {
- Mark(attr);
- if (!attr->Is<ast::InternalAttribute>()) {
- AddError("attributes are not valid on local variables", attr->source);
- return false;
- }
- }
-
current_compound_statement_->AddDecl(variable->As<sem::LocalVariable>());
if (auto* ctor = variable->Initializer()) {
@@ -4339,16 +4585,16 @@
// Helper to handle attributes that are supported on certain types of statement.
auto handle_attributes = [&](auto* stmt, sem::Statement* sem_stmt, const char* use) {
- for (auto* attr : stmt->attributes) {
- Mark(attr);
- if (auto* dc = attr->template As<ast::DiagnosticAttribute>()) {
- if (!DiagnosticControl(dc->control)) {
+ for (auto* attribute : stmt->attributes) {
+ Mark(attribute);
+ bool ok = Switch(
+ attribute, //
+ [&](const ast::DiagnosticAttribute* attr) { return DiagnosticAttribute(attr); },
+ [&](Default) {
+ ErrorInvalidAttribute(attribute, use);
return false;
- }
- } else {
- utils::StringStream ss;
- ss << "attribute is not valid for " << use;
- AddError(ss.str(), attr->source);
+ });
+ if (!ok) {
return false;
}
}
@@ -4451,6 +4697,10 @@
sem_.NoteDeclarationSource(resolved.Node());
}
+void Resolver::ErrorInvalidAttribute(const ast::Attribute* attr, std::string_view use) {
+ AddError("@" + attr->Name() + " is not valid for " + std::string(use), attr->source);
+}
+
void Resolver::AddError(const std::string& msg, const Source& source) const {
diagnostics_.add_error(diag::System::Resolver, msg, source);
}
diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h
index edc088a..b26ad96 100644
--- a/src/tint/resolver/resolver.h
+++ b/src/tint/resolver/resolver.h
@@ -312,13 +312,45 @@
/// current_function_
bool WorkgroupSize(const ast::Function*);
- /// Resolves the attribute @p attr
- /// @returns true on success, false on failure
- bool Attribute(const ast::Attribute* attr);
-
/// Resolves the `@builtin` attribute @p attr
+ /// @returns the builtin value on success
+ utils::Result<tint::builtin::BuiltinValue> BuiltinAttribute(const ast::BuiltinAttribute* attr);
+
+ /// Resolves the `@location` attribute @p attr
+ /// @returns the location value on success.
+ utils::Result<uint32_t> LocationAttribute(const ast::LocationAttribute* attr);
+
+ /// Resolves the `@binding` attribute @p attr
+ /// @returns the binding value on success.
+ utils::Result<uint32_t> BindingAttribute(const ast::BindingAttribute* attr);
+
+ /// Resolves the `@group` attribute @p attr
+ /// @returns the group value on success.
+ utils::Result<uint32_t> GroupAttribute(const ast::GroupAttribute* attr);
+
+ /// Resolves the `@workgroup_size` attribute @p attr
+ /// @returns the workgroup size on success.
+ utils::Result<sem::WorkgroupSize> WorkgroupAttribute(const ast::WorkgroupAttribute* attr);
+
+ /// Resolves the `@diagnostic` attribute @p attr
/// @returns true on success, false on failure
- bool BuiltinAttribute(const ast::BuiltinAttribute* attr);
+ bool DiagnosticAttribute(const ast::DiagnosticAttribute* attr);
+
+ /// Resolves the stage attribute @p attr
+ /// @returns true on success, false on failure
+ bool StageAttribute(const ast::StageAttribute* attr);
+
+ /// Resolves the `@must_use` attribute @p attr
+ /// @returns true on success, false on failure
+ bool MustUseAttribute(const ast::MustUseAttribute* attr);
+
+ /// Resolves the `@invariant` attribute @p attr
+ /// @returns true on success, false on failure
+ bool InvariantAttribute(const ast::InvariantAttribute*);
+
+ /// Resolves the `@stride` attribute @p attr
+ /// @returns true on success, false on failure
+ bool StrideAttribute(const ast::StrideAttribute*);
/// Resolves the `@interpolate` attribute @p attr
/// @returns true on success, false on failure
@@ -427,12 +459,11 @@
/// nullptr is returned.
/// @note the caller is expected to validate the parameter
/// @param param the AST parameter
+ /// @param func the AST function that owns the parameter
/// @param index the index of the parameter
- sem::Parameter* Parameter(const ast::Parameter* param, uint32_t index);
-
- /// @returns the location value for a `@location` attribute, validating the value's range and
- /// type.
- utils::Result<uint32_t> LocationAttribute(const ast::LocationAttribute* attr);
+ sem::Parameter* Parameter(const ast::Parameter* param,
+ const ast::Function* func,
+ uint32_t index);
/// Records the address space usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the
@@ -497,6 +528,11 @@
const ResolvedIdentifier& resolved,
std::string_view wanted);
+ /// Raises an error that the attribute is not valid for the given use.
+ /// @param attr the invalue attribute
+ /// @param use the thing that the attribute was applied to
+ void ErrorInvalidAttribute(const ast::Attribute* attr, std::string_view use);
+
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;
diff --git a/src/tint/resolver/unresolved_identifier_test.cc b/src/tint/resolver/unresolved_identifier_test.cc
index e52b858..5802005 100644
--- a/src/tint/resolver/unresolved_identifier_test.cc
+++ b/src/tint/resolver/unresolved_identifier_test.cc
@@ -43,7 +43,7 @@
Func("f",
utils::Vector{
Param("p", ty.i32(), utils::Vector{Builtin(Expr(Source{{12, 34}}, "positon"))})},
- ty.void_(), utils::Empty);
+ ty.void_(), utils::Empty, utils::Vector{Stage(ast::PipelineStage::kVertex)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved builtin value 'positon'
diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc
index 97eae88..c1ba335 100644
--- a/src/tint/resolver/validator.cc
+++ b/src/tint/resolver/validator.cc
@@ -606,32 +606,10 @@
return false;
}
- for (auto* attr : decl->attributes) {
- bool is_shader_io_attribute =
- attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
- ast::InvariantAttribute, ast::LocationAttribute>();
- bool has_io_address_space = global->AddressSpace() == builtin::AddressSpace::kIn ||
- global->AddressSpace() == builtin::AddressSpace::kOut;
- if (!attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute,
- ast::InternalAttribute>() &&
- (!is_shader_io_attribute || !has_io_address_space)) {
- AddError("attribute '" + attr->Name() + "' is not valid for module-scope 'var'",
- attr->source);
- return false;
- }
- }
-
return Var(global);
},
[&](const ast::Override*) { return Override(global, override_ids); },
- [&](const ast::Const*) {
- if (!decl->attributes.IsEmpty()) {
- AddError("attribute is not valid for module-scope 'const' declaration",
- decl->attributes[0]->source);
- return false;
- }
- return Const(global);
- },
+ [&](const ast::Const*) { return Const(global); },
[&](Default) {
TINT_ICE(Resolver, diagnostics_)
<< "Validator::GlobalVariable() called with a unknown variable type: "
@@ -773,9 +751,6 @@
ast::GetAttribute<ast::IdAttribute>((*var)->Declaration()->attributes)->source);
return false;
}
- } else {
- AddError("attribute is not valid for 'override' declaration", attr->source);
- return false;
}
}
@@ -792,28 +767,13 @@
return true;
}
-bool Validator::Parameter(const ast::Function* func, const sem::Variable* var) const {
+bool Validator::Parameter(const sem::Variable* var) const {
auto* decl = var->Declaration();
if (IsValidationDisabled(decl->attributes, ast::DisabledValidation::kFunctionParameter)) {
return true;
}
- for (auto* attr : decl->attributes) {
- if (!func->IsEntryPoint() && !attr->Is<ast::InternalAttribute>()) {
- AddError("attribute is not valid for non-entry point function parameters",
- attr->source);
- return false;
- }
- if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InvariantAttribute, ast::LocationAttribute,
- ast::InterpolateAttribute, ast::InternalAttribute>() &&
- (IsValidationEnabled(decl->attributes,
- ast::DisabledValidation::kEntryPointParameter))) {
- AddError("attribute is not valid for function parameters", attr->source);
- return false;
- }
- }
-
if (auto* ref = var->Type()->As<type::Pointer>()) {
if (IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreAddressSpace)) {
bool ok = false;
@@ -1028,14 +988,7 @@
}
return true;
},
- [&](Default) {
- if (!attr->IsAnyOf<ast::DiagnosticAttribute, ast::StageAttribute,
- ast::InternalAttribute>()) {
- AddError("attribute is not valid for functions", attr->source);
- return false;
- }
- return true;
- });
+ [&](Default) { return true; });
if (!ok) {
return false;
}
@@ -1069,24 +1022,6 @@
TINT_ICE(Resolver, diagnostics_)
<< "Function " << decl->name->symbol.Name() << " has no body";
}
-
- for (auto* attr : decl->return_type_attributes) {
- if (!decl->IsEntryPoint()) {
- AddError("attribute is not valid for non-entry point function return types",
- attr->source);
- return false;
- }
- if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InternalAttribute,
- ast::LocationAttribute, ast::InterpolateAttribute,
- ast::InvariantAttribute>() &&
- (IsValidationEnabled(decl->attributes,
- ast::DisabledValidation::kEntryPointParameter) &&
- IsValidationEnabled(decl->attributes,
- ast::DisabledValidation::kFunctionParameter))) {
- AddError("attribute is not valid for entry point return types", attr->source);
- return false;
- }
- }
}
if (decl->IsEntryPoint()) {
@@ -1196,7 +1131,7 @@
if (is_invalid_compute_shader_attribute) {
std::string input_or_output =
param_or_ret == ParamOrRetType::kParameter ? "inputs" : "output";
- AddError("attribute is not valid for compute shader " + input_or_output,
+ AddError("@" + attr->Name() + " is not valid for compute shader " + input_or_output,
attr->source);
return false;
}
@@ -2205,24 +2140,7 @@
}
return true;
},
- [&](Default) {
- if (!attr->IsAnyOf<ast::BuiltinAttribute, //
- ast::InternalAttribute, //
- ast::InterpolateAttribute, //
- ast::InvariantAttribute, //
- ast::LocationAttribute, //
- ast::StructMemberOffsetAttribute, //
- ast::StructMemberAlignAttribute>()) {
- if (attr->Is<ast::StrideAttribute>() &&
- IsValidationDisabled(member->Declaration()->attributes,
- ast::DisabledValidation::kIgnoreStrideAttribute)) {
- return true;
- }
- AddError("attribute is not valid for structure members", attr->source);
- return false;
- }
- return true;
- });
+ [&](Default) { return true; });
if (!ok) {
return false;
}
@@ -2241,13 +2159,6 @@
}
}
- for (auto* attr : str->Declaration()->attributes) {
- if (!(attr->IsAnyOf<ast::InternalAttribute>())) {
- AddError("attribute is not valid for struct declarations", attr->source);
- return false;
- }
- }
-
return true;
}
@@ -2260,7 +2171,8 @@
const bool is_input) const {
std::string inputs_or_output = is_input ? "inputs" : "output";
if (stage == ast::PipelineStage::kCompute) {
- AddError("attribute is not valid for compute shader " + inputs_or_output, loc_attr->source);
+ AddError("@" + loc_attr->Name() + " is not valid for compute shader " + inputs_or_output,
+ loc_attr->source);
return false;
}
diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h
index e0e3051..1ab38bc 100644
--- a/src/tint/resolver/validator.h
+++ b/src/tint/resolver/validator.h
@@ -348,10 +348,9 @@
bool Matrix(const type::Type* el_ty, const Source& source) const;
/// Validates a function parameter
- /// @param func the function the variable is for
/// @param var the variable to validate
/// @returns true on success, false otherwise
- bool Parameter(const ast::Function* func, const sem::Variable* var) const;
+ bool Parameter(const sem::Variable* var) const;
/// Validates a return
/// @param ret the return statement to validate
diff --git a/src/tint/resolver/variable_test.cc b/src/tint/resolver/variable_test.cc
index 7af6027..947f70c 100644
--- a/src/tint/resolver/variable_test.cc
+++ b/src/tint/resolver/variable_test.cc
@@ -383,7 +383,7 @@
}
////////////////////////////////////////////////////////////////////////////////////////////////////
-// Function-scope 'let'
+// 'let' declaration
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(ResolverVariableTest, LocalLet) {
// struct S { i : i32; }