validation: validate builtin pipeline stage and Input/Output
Bug: tint:957
Change-Id: I5f509e61501b39f2a0b3bc10a204ae1f39a0d460
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57105
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 3dc70ab..3c9b935 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -1059,11 +1059,11 @@
TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) {
auto* in_var0 =
- Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kInstanceIndex)});
- auto* in_var1 = Param("in_var1", ty.u32(), {Location(0u)});
- Func("foo", {in_var0, in_var1}, ty.u32(), {Return("in_var1")},
+ Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)});
+ auto* in_var1 = Param("in_var1", ty.f32(), {Location(0u)});
+ Func("foo", {in_var0, in_var1}, ty.f32(), {Return("in_var1")},
{Stage(ast::PipelineStage::kFragment)},
- {Builtin(ast::Builtin::kSampleMask)});
+ {Builtin(ast::Builtin::kFragDepth)});
Inspector& inspector = Build();
auto result = inspector.GetEntryPoints();
@@ -1075,7 +1075,7 @@
EXPECT_EQ("in_var1", result[0].input_variables[0].name);
EXPECT_TRUE(result[0].input_variables[0].has_location_decoration);
EXPECT_EQ(0u, result[0].input_variables[0].location_decoration);
- EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type);
+ EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[0].component_type);
ASSERT_EQ(0u, result[0].output_variables.size());
}
diff --git a/src/resolver/builtins_validation_test.cc b/src/resolver/builtins_validation_test.cc
index 6c06ce0..46a2b66 100644
--- a/src/resolver/builtins_validation_test.cc
+++ b/src/resolver/builtins_validation_test.cc
@@ -16,9 +16,209 @@
#include "src/resolver/resolver_test_helper.h"
namespace tint {
+namespace resolver {
namespace {
+
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <typename T>
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
class ResolverBuiltinsValidationTest : public resolver::TestHelper,
public testing::Test {};
+namespace TypeTemp {
+struct Params {
+ builder::ast_type_func_ptr type;
+ ast::Builtin builtin;
+ ast::PipelineStage stage;
+ bool is_valid;
+};
+
+template <typename T>
+constexpr Params ParamsFor(ast::Builtin builtin,
+ ast::PipelineStage stage,
+ bool is_valid) {
+ return Params{DataType<T>::AST, builtin, stage, is_valid};
+}
+static constexpr Params cases[] = {
+ ParamsFor<u32>(ast::Builtin::kVertexIndex,
+ ast::PipelineStage::kVertex,
+ true),
+ ParamsFor<u32>(ast::Builtin::kVertexIndex,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<u32>(ast::Builtin::kVertexIndex,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+ ast::PipelineStage::kVertex,
+ true),
+ ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<bool>(ast::Builtin::kFrontFacing,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<bool>(ast::Builtin::kFrontFacing,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<bool>(ast::Builtin::kFrontFacing,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+ ast::PipelineStage::kFragment,
+ false),
+ ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+ ast::PipelineStage::kCompute,
+ true),
+
+ ParamsFor<u32>(ast::Builtin::kSampleIndex,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<u32>(ast::Builtin::kSampleIndex,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<u32>(ast::Builtin::kSampleIndex,
+ ast::PipelineStage::kCompute,
+ false),
+
+ ParamsFor<u32>(ast::Builtin::kSampleMask,
+ ast::PipelineStage::kVertex,
+ false),
+ ParamsFor<u32>(ast::Builtin::kSampleMask,
+ ast::PipelineStage::kFragment,
+ true),
+ ParamsFor<u32>(ast::Builtin::kSampleMask,
+ ast::PipelineStage::kCompute,
+ false),
+};
+
+using ResolverBuiltinsStageTest = ResolverTestWithParam<Params>;
+TEST_P(ResolverBuiltinsStageTest, All_input) {
+ const Params& params = GetParam();
+
+ auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+ auto* input =
+ Param("input", params.type(*this),
+ ast::DecorationList{Builtin(Source{{12, 34}}, params.builtin)});
+ switch (params.stage) {
+ case ast::PipelineStage::kVertex:
+ Func("main", {input}, ty.vec4<f32>(), {Return(p)},
+ {Stage(ast::PipelineStage::kVertex)},
+ {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
+ break;
+ case ast::PipelineStage::kFragment:
+ Func("main", {input}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)}, {});
+ break;
+ case ast::PipelineStage::kCompute:
+ Func("main", {input}, ty.void_(), {},
+ ast::DecorationList{Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(1)});
+ break;
+ default:
+ break;
+ }
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ std::stringstream err;
+ err << "12:34 error: builtin(" << params.builtin << ")";
+ err << " cannot be used in input of " << params.stage << " pipeline stage";
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), err.str());
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+ ResolverBuiltinsStageTest,
+ testing::ValuesIn(cases));
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
+ // [[stage(fragment)]]
+ // fn fs_main(
+ // [[builtin(kFragDepth)]] fd: f32,
+ // ) -> [[location(0)]] f32 { return 1.0; }
+ auto* fd = Param(
+ "fd", ty.f32(),
+ ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
+ Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
+ ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
+ {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: builtin(frag_depth) cannot be used in input of "
+ "fragment pipeline stage");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) {
+ // Struct MyInputs {
+ // [[builtin(front_facing)]] ff: bool;
+ // };
+ // [[stage(fragment)]]
+ // fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
+
+ auto* s = Structure(
+ "MyInputs", {Member("frag_depth", ty.f32(),
+ ast::DecorationList{Builtin(
+ Source{{12, 34}}, ast::Builtin::kFragDepth)})});
+
+ Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: builtin(frag_depth) cannot be used in input of fragment "
+ "pipeline stage\nnote: while analysing entry point fragShader");
+}
+} // namespace TypeTemp
TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
// struct MyInputs {
@@ -170,15 +370,12 @@
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) {
// [[stage(fragment)]]
- // fn fs_main(
- // [[builtin(kFragDepth)]] fd: f32,
- // ) -> [[location(0)]] f32 { return 1.0; }
- auto* fd = Param(
- "fd", ty.i32(),
+ // fn fs_main() -> [[builtin(kFragDepth)]] f32 { var fd: i32; return fd; }
+ auto* fd = Var("fd", ty.i32());
+ Func(
+ "fs_main", {}, ty.i32(), {Decl(fd), Return(fd)},
+ ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
- Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
- ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
- {Location(0)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: store type of builtin(frag_depth) must be 'f32'");
@@ -227,44 +424,43 @@
// fn fs_main(
// [[builtin(kPosition)]] p: vec4<f32>,
// [[builtin(front_facing)]] ff: bool,
- // [[builtin(frag_depth)]] fd: f32,
// [[builtin(sample_index)]] si: u32,
// [[builtin(sample_mask)]] sm : u32
- // ) -> [[location(0)]] f32 { return 1.0; }
+ // ) -> [[builtin(frag_depth)]] f32 { var fd: f32; return fd; }
auto* p = Param("p", ty.vec4<f32>(),
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
auto* ff = Param("ff", ty.bool_(),
ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)});
- auto* fd = Param("fd", ty.f32(),
- ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
auto* si = Param("si", ty.u32(),
ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)});
auto* sm = Param("sm", ty.u32(),
ast::DecorationList{Builtin(ast::Builtin::kSampleMask)});
- Func(
- "fs_main", ast::VariableList{p, ff, fd, si, sm}, ty.f32(), {Return(1.0f)},
- ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+ auto* var_fd = Var("fd", ty.f32());
+ Func("fs_main", ast::VariableList{p, ff, si, sm}, ty.f32(),
+ {Decl(var_fd), Return(var_fd)},
+ ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
+ ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) {
// [[stage(vertex)]]
// fn main(
- // [[builtin(kVertexIndex)]] vi : u32,
- // [[builtin(kInstanceIndex)]] ii : u32,
- // [[builtin(kPosition)]] p :vec4<f32>
- // ) {}
+ // [[builtin(vertex_index)]] vi : u32,
+ // [[builtin(instance_index)]] ii : u32,
+ // ) -> [[builtin(position)]] vec4<f32> { var p :vec4<f32>; return p; }
auto* vi = Param("vi", ty.u32(),
ast::DecorationList{
Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
- auto* p = Param("p", ty.vec4<f32>(),
- ast::DecorationList{Builtin(ast::Builtin::kPosition)});
+
auto* ii = Param("ii", ty.u32(),
ast::DecorationList{Builtin(Source{{12, 34}},
ast::Builtin::kInstanceIndex)});
- Func("main", ast::VariableList{vi, ii, p}, ty.vec4<f32>(),
+ auto* p = Var("p", ty.vec4<f32>());
+ Func("main", ast::VariableList{vi, ii}, ty.vec4<f32>(),
{
- Return(Expr(p)),
+ Decl(p),
+ Return(p),
},
ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
ast::DecorationList{Builtin(ast::Builtin::kPosition)});
@@ -369,7 +565,6 @@
TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
// Struct MyInputs {
// [[builtin(kPosition)]] p: vec4<f32>;
- // [[builtin(front_facing)]] ff: bool;
// [[builtin(frag_depth)]] fd: f32;
// [[builtin(sample_index)]] si: u32;
// [[builtin(sample_mask)]] sm : u32;;
@@ -383,8 +578,6 @@
ast::DecorationList{Builtin(ast::Builtin::kPosition)}),
Member("front_facing", ty.bool_(),
ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}),
- Member("frag_depth", ty.f32(),
- ast::DecorationList{Builtin(ast::Builtin::kFragDepth)}),
Member("sample_index", ty.u32(),
ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}),
Member("sample_mask", ty.u32(),
@@ -1006,4 +1199,5 @@
"pack2x16float"));
} // namespace
+} // namespace resolver
} // namespace tint
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index ba261a8..61f8d36 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -145,7 +145,8 @@
} else {
EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
- "error: decoration is not valid for function parameters");
+ "error: decoration is not valid for non-entry point function "
+ "parameters");
}
}
INSTANTIATE_TEST_SUITE_P(
@@ -244,7 +245,8 @@
} else {
EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
- "error: decoration is not valid for function return types");
+ "error: decoration is not valid for non-entry point function "
+ "return types");
}
}
INSTANTIATE_TEST_SUITE_P(
diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc
index 6b63223..bd0201f 100644
--- a/src/resolver/entry_point_validation_test.cc
+++ b/src/resolver/entry_point_validation_test.cc
@@ -289,16 +289,6 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
-TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Builtin) {
- // [[stage(fragment)]]
- // fn main([[builtin(frag_depth)]] param : f32) {}
- auto* param = Param("param", ty.f32(), {Builtin(ast::Builtin::kFragDepth)});
- Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
- {Stage(ast::PipelineStage::kFragment)});
-
- EXPECT_TRUE(r()->Resolve()) << r()->error();
-}
-
TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) {
// [[stage(fragment)]]
// fn main(param : f32) {}
@@ -313,10 +303,10 @@
TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) {
// [[stage(fragment)]]
- // fn main([[location(0)]] [[builtin(vertex_index)]] param : u32) {}
+ // fn main([[location(0)]] [[builtin(sample_index)]] param : u32) {}
auto* param = Param("param", ty.u32(),
{Location(Source{{13, 43}}, 0),
- Builtin(Source{{14, 52}}, ast::Builtin::kVertexIndex)});
+ Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)});
Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment)});
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 7a4c7a7..7002b5a 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -928,20 +928,15 @@
for (auto* deco : info->declaration->decorations()) {
if (!func->IsEntryPoint() && !deco->Is<ast::InternalDecoration>()) {
- AddError("decoration is not valid for function parameters",
- deco->source());
+ AddError(
+ "decoration is not valid for non-entry point function parameters",
+ deco->source());
return false;
- }
-
- if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
- if (!ValidateBuiltinDecoration(builtin, info->type)) {
- return false;
- }
} else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
if (!ValidateInterpolateDecoration(interpolate, info->type)) {
return false;
}
- } else if (!deco->IsAnyOf<ast::LocationDecoration,
+ } else if (!deco->IsAnyOf<ast::LocationDecoration, ast::BuiltinDecoration,
ast::InternalDecoration>() &&
(IsValidationEnabled(
info->declaration->decorations(),
@@ -989,10 +984,25 @@
}
bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
- const sem::Type* storage_type) {
+ const sem::Type* storage_type,
+ const bool is_input) {
auto* type = storage_type->UnwrapRef();
+ const auto stage = current_function_
+ ? current_function_->declaration->pipeline_stage()
+ : ast::PipelineStage::kNone;
+ std::stringstream stage_name;
+ stage_name << stage;
+ bool is_stage_mismatch = false;
switch (deco->value()) {
case ast::Builtin::kPosition:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && is_input) &&
+ !(stage == ast::PipelineStage::kVertex && !is_input)) {
+ AddError(deco_to_str(deco) + " cannot be used in " +
+ (is_input ? "input of " : "output of ") +
+ stage_name.str() + " pipeline stage",
+ deco->source());
+ }
if (!(type->is_float_vector() && type->As<sem::Vector>()->size() == 4)) {
AddError("store type of " + deco_to_str(deco) + " must be 'vec4<f32>'",
deco->source());
@@ -1002,6 +1012,10 @@
case ast::Builtin::kGlobalInvocationId:
case ast::Builtin::kLocalInvocationId:
case ast::Builtin::kWorkgroupId:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kCompute && is_input)) {
+ is_stage_mismatch = true;
+ }
if (!(type->is_unsigned_integer_vector() &&
type->As<sem::Vector>()->size() == 3)) {
AddError("store type of " + deco_to_str(deco) + " must be 'vec3<u32>'",
@@ -1010,6 +1024,10 @@
}
break;
case ast::Builtin::kFragDepth:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && !is_input)) {
+ is_stage_mismatch = true;
+ }
if (!type->Is<sem::F32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'f32'",
deco->source());
@@ -1017,6 +1035,10 @@
}
break;
case ast::Builtin::kFrontFacing:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && is_input)) {
+ is_stage_mismatch = true;
+ }
if (!type->Is<sem::Bool>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'bool'",
deco->source());
@@ -1024,10 +1046,44 @@
}
break;
case ast::Builtin::kLocalInvocationIndex:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kCompute && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
+ deco->source());
+ return false;
+ }
+ break;
case ast::Builtin::kVertexIndex:
case ast::Builtin::kInstanceIndex:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kVertex && is_input)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
+ deco->source());
+ return false;
+ }
+ break;
case ast::Builtin::kSampleMask:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment)) {
+ is_stage_mismatch = true;
+ }
+ if (!type->Is<sem::U32>()) {
+ AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
+ deco->source());
+ return false;
+ }
+ break;
case ast::Builtin::kSampleIndex:
+ if (stage != ast::PipelineStage::kNone &&
+ !(stage == ast::PipelineStage::kFragment && is_input)) {
+ is_stage_mismatch = true;
+ }
if (!type->Is<sem::U32>()) {
AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
deco->source());
@@ -1037,6 +1093,15 @@
default:
break;
}
+
+ if (is_stage_mismatch) {
+ AddError(deco_to_str(deco) + " cannot be used in " +
+ (is_input ? "input of " : "output of ") + stage_name.str() +
+ " pipeline stage",
+ deco->source());
+ return false;
+ }
+
return true;
}
@@ -1070,12 +1135,9 @@
return false;
}
- auto stage_deco_count = 0;
auto workgroup_deco_count = 0;
for (auto* deco : func->decorations()) {
- if (deco->Is<ast::StageDecoration>()) {
- stage_deco_count++;
- } else if (deco->Is<ast::WorkgroupDecoration>()) {
+ if (deco->Is<ast::WorkgroupDecoration>()) {
workgroup_deco_count++;
if (func->pipeline_stage() != ast::PipelineStage::kCompute) {
AddError(
@@ -1083,7 +1145,8 @@
deco->source());
return false;
}
- } else if (!deco->Is<ast::InternalDecoration>()) {
+ } else if (!deco->IsAnyOf<ast::StageDecoration,
+ ast::InternalDecoration>()) {
AddError("decoration is not valid for functions", deco->source());
return false;
}
@@ -1119,20 +1182,24 @@
for (auto* deco : func->return_type_decorations()) {
if (!func->IsEntryPoint()) {
- AddError("decoration is not valid for function return types",
- deco->source());
+ AddError(
+ "decoration is not valid for non-entry point function return types",
+ deco->source());
return false;
}
- if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
- if (!ValidateBuiltinDecoration(builtin, info->return_type)) {
- return false;
- }
- } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
+ if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
if (!ValidateInterpolateDecoration(interpolate, info->return_type)) {
return false;
}
- } else if (!deco->Is<ast::LocationDecoration>()) {
+ } else if (!deco->IsAnyOf<ast::LocationDecoration, ast::BuiltinDecoration,
+ ast::InternalDecoration>() &&
+ (IsValidationEnabled(
+ info->declaration->decorations(),
+ ast::DisabledValidation::kEntryPointParameter) &&
+ IsValidationEnabled(info->declaration->decorations(),
+ ast::DisabledValidation::
+ kIgnoreAtomicFunctionParameter))) {
AddError("decoration is not valid for entry point return types",
deco->source());
return false;
@@ -1192,6 +1259,12 @@
}
builtins.emplace(builtin->value());
+ if (!ValidateBuiltinDecoration(builtin, ty,
+ /* is_input */ param_or_ret ==
+ ParamOrRetType::kParameter)) {
+ return false;
+ }
+
} else if (auto* location = deco->As<ast::LocationDecoration>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", deco->source());
@@ -1409,7 +1482,6 @@
return false;
}
- // TODO(amaiorano): Validate parameter decorations
for (auto* deco : param->decorations()) {
Mark(deco);
}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index e8e553f..3306294 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -273,7 +273,8 @@
bool ValidateAtomicUses();
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
- const sem::Type* storage_type);
+ const sem::Type* storage_type,
+ const bool is_input = true);
bool ValidateCallStatement(ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index fee3f57..32b01d2 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -85,9 +85,8 @@
// fn f2(){ dst = wg; }
// fn f1() { f2(); }
// [[stage(fragment)]]
- // fn f0() -> [[builtin(position)]] vec4<f32> {
+ // fn f0() {
// f1();
- // return dst;
//}
Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
@@ -97,10 +96,9 @@
Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt});
Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(),
{Ignore(Call("f2"))});
- Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4<f32>(),
- {Ignore(Call("f1")), Return(Expr("dst"))},
- ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
- ast::DecorationList{Builtin(ast::Builtin::kPosition)});
+ Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.void_(),
+ {Ignore(Call("f1"))},
+ ast::DecorationList{Stage(ast::PipelineStage::kFragment)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index ebadbb4..01cc199 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -296,7 +296,7 @@
TEST_F(BuilderTest, SampleIndex_SampleRateShadingCapability) {
Func("main",
{Param("sample_index", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})},
- ty.void_(), {}, {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+ ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)});
spirv::Builder& b = SanitizeAndBuild();