Test that entry point IO attributes are of valid types
BUG=tint:773
Change-Id: I94e8624647c645efe7ed558caa3d3bd05dd72f63
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50260
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc
index f872844..3a7eef3 100644
--- a/src/resolver/entry_point_validation_test.cc
+++ b/src/resolver/entry_point_validation_test.cc
@@ -23,13 +23,14 @@
#include "gmock/gmock.h"
namespace tint {
+namespace resolver {
namespace {
-class ResolverEntryPointValidationTest : public resolver::TestHelper,
+class ResolverEntryPointValidationTest : public TestHelper,
public testing::Test {};
TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Location) {
- // [[stage(vertex)]]
+ // [[stage(fragment)]]
// fn main() -> [[location(0)]] f32 { return 1.0; }
Func(Source{{12, 34}}, "main", {}, ty.f32(), {Return(1.0f)},
{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
@@ -514,5 +515,118 @@
"in its return type");
}
+namespace TypeValidationTests {
+struct Params {
+ create_ast_type_func_ptr create_ast_type;
+ bool is_valid;
+};
+
+using TypeValidationTest = resolver::ResolverTestWithParam<Params>;
+
+static constexpr Params cases[] = {
+ {ast_f32, true},
+ {ast_i32, true},
+ {ast_u32, true},
+ {ast_bool, false},
+ {ast_vec2<ast_f32>, true},
+ {ast_vec3<ast_f32>, true},
+ {ast_vec4<ast_f32>, true},
+ {ast_mat2x2<ast_f32>, false},
+ {ast_mat2x2<ast_i32>, false},
+ {ast_mat2x2<ast_u32>, false},
+ {ast_mat2x2<ast_bool>, false},
+ {ast_mat3x3<ast_f32>, false},
+ {ast_mat3x3<ast_i32>, false},
+ {ast_mat3x3<ast_u32>, false},
+ {ast_mat3x3<ast_bool>, false},
+ {ast_mat4x4<ast_f32>, false},
+ {ast_mat4x4<ast_i32>, false},
+ {ast_mat4x4<ast_u32>, false},
+ {ast_mat4x4<ast_bool>, false},
+ {ast_alias<ast_f32>, true},
+ {ast_alias<ast_i32>, true},
+ {ast_alias<ast_u32>, true},
+ {ast_alias<ast_bool>, false},
+};
+
+TEST_P(TypeValidationTest, BareInputs) {
+ // [[stage(fragment)]]
+ // fn main([[location(0)]] a : *) {}
+ auto params = GetParam();
+ auto* a = Param("a", params.create_ast_type(ty), {Location(0)});
+ Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+
+TEST_P(TypeValidationTest, StructInputs) {
+ // struct Input {
+ // [[location(0)]] a : *;
+ // };
+ // [[stage(fragment)]]
+ // fn main(a : Input) {}
+ auto params = GetParam();
+ auto* input = Structure(
+ "Input", {Member("a", params.create_ast_type(ty), {Location(0)})});
+ auto* a = Param("a", input, {});
+ Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+
+TEST_P(TypeValidationTest, BareOutputs) {
+ // [[stage(fragment)]]
+ // fn main() -> [[location(0)]] * {
+ // return *();
+ // }
+ auto params = GetParam();
+ Func(Source{{12, 34}}, "main", {}, params.create_ast_type(ty),
+ {Return(Construct(params.create_ast_type(ty)))},
+ {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+
+TEST_P(TypeValidationTest, StructOutputs) {
+ // struct Output {
+ // [[location(0)]] a : *;
+ // };
+ // [[stage(fragment)]]
+ // fn main() -> Output {
+ // return Output();
+ // }
+ auto params = GetParam();
+ auto* output = Structure(
+ "Output", {Member("a", params.create_ast_type(ty), {Location(0)})});
+ Func(Source{{12, 34}}, "main", {}, output, {Return(Construct(output))},
+ {Stage(ast::PipelineStage::kFragment)});
+
+ if (params.is_valid) {
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+ } else {
+ EXPECT_FALSE(r()->Resolve());
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverEntryPointValidationTest,
+ TypeValidationTest,
+ testing::ValuesIn(cases));
+
+} // namespace TypeValidationTests
+
} // namespace
+} // namespace resolver
} // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index a155ff4..32273e2 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -832,6 +832,19 @@
diagnostics_.add_error(err, source);
return false;
}
+
+ // Check that all user defined attributes are numeric scalars, vectors
+ // of numeric scalars.
+ // Testing for being a struct is handled by the if portion above.
+ if (!pipeline_io_attribute->Is<ast::BuiltinDecoration>()) {
+ if (!Canonical(ty)->is_numeric_scalar_or_vector()) {
+ diagnostics_.add_error(
+ "User defined entry point IO types must be a numeric scalar, "
+ "a numeric vector, or a structure",
+ source);
+ return false;
+ }
+ }
}
return true;
diff --git a/src/sem/type.cc b/src/sem/type.cc
index 947b149..589688f 100644
--- a/src/sem/type.cc
+++ b/src/sem/type.cc
@@ -132,6 +132,15 @@
return Is<Bool>() || is_bool_vector();
}
+bool Type::is_numeric_vector() const {
+ return Is<Vector>(
+ [](const Vector* v) { return v->type()->is_numeric_scalar(); });
+}
+
+bool Type::is_numeric_scalar_or_vector() const {
+ return is_numeric_scalar() || is_numeric_vector();
+}
+
bool Type::is_handle() const {
return IsAnyOf<Sampler, Texture>();
}
diff --git a/src/sem/type.h b/src/sem/type.h
index 49da00c..ca7bdd4 100644
--- a/src/sem/type.h
+++ b/src/sem/type.h
@@ -125,6 +125,10 @@
bool is_bool_vector() const;
/// @returns true if this type is boolean scalar or vector
bool is_bool_scalar_or_vector() const;
+ /// @returns true if this type is a numeric vector
+ bool is_numeric_vector() const;
+ /// @returns true if this type is a numeric scale or vector
+ bool is_numeric_scalar_or_vector() const;
/// @returns true if this type is a handle type
bool is_handle() const;
diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc
index 2f4b4ab..876c21d 100644
--- a/src/transform/canonicalize_entry_point_io_test.cc
+++ b/src/transform/canonicalize_entry_point_io_test.cc
@@ -508,13 +508,13 @@
struct VertexOutput {
[[location(1)]] b : u32;
[[builtin(position)]] pos : vec4<f32>;
- [[location(3)]] d : bool;
+ [[location(3)]] d : u32;
[[location(0)]] a : f32;
[[location(2)]] c : i32;
};
struct FragmentInputExtra {
- [[location(3)]] d : bool;
+ [[location(3)]] d : u32;
[[builtin(position)]] pos : vec4<f32>;
[[location(0)]] a : f32;
};
@@ -536,13 +536,13 @@
struct VertexOutput {
b : u32;
pos : vec4<f32>;
- d : bool;
+ d : u32;
a : f32;
c : i32;
};
struct FragmentInputExtra {
- d : bool;
+ d : u32;
pos : vec4<f32>;
a : f32;
};
@@ -555,7 +555,7 @@
[[location(2)]]
c : i32;
[[location(3)]]
- d : bool;
+ d : u32;
[[builtin(position)]]
pos : vec4<f32>;
};
@@ -574,7 +574,7 @@
[[location(2)]]
c : i32;
[[location(3)]]
- d : bool;
+ d : u32;
[[builtin(position)]]
pos : vec4<f32>;
[[builtin(front_facing)]]