resolver: Validate vector types
Fixed: tint:953
Change-Id: I3742680e49894a93db41219e512796ba9bdf036a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56778
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 307f807..8d9f265 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -324,8 +324,12 @@
}
if (auto* t = ty->As<ast::Vector>()) {
if (auto* el = Type(t->type())) {
- return builder_->create<sem::Vector>(const_cast<sem::Type*>(el),
- t->size());
+ if (auto* vector = builder_->create<sem::Vector>(
+ const_cast<sem::Type*>(el), t->size())) {
+ if (ValidateVector(vector, t->source())) {
+ return vector;
+ }
+ }
}
return nullptr;
}
@@ -333,10 +337,10 @@
if (auto* el = Type(t->type())) {
if (auto* column_type = builder_->create<sem::Vector>(
const_cast<sem::Type*>(el), t->rows())) {
- if (auto* matrix_type =
+ if (auto* matrix =
builder_->create<sem::Matrix>(column_type, t->columns())) {
- if (ValidateMatrix(matrix_type, t->source())) {
- return matrix_type;
+ if (ValidateMatrix(matrix, t->source())) {
+ return matrix;
}
}
}
@@ -2300,14 +2304,22 @@
return true;
}
-bool Resolver::ValidateMatrix(const sem::Matrix* matrix_type,
- const Source& source) {
- if (!matrix_type->is_float_matrix()) {
+bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
+ if (!ty->type()->is_scalar()) {
+ AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
+ source);
+ return false;
+ }
+ return true;
+}
+
+bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
+ if (!ty->is_float_matrix()) {
AddError("matrix element type must be 'f32'", source);
return false;
}
return true;
-} // namespace resolver
+}
bool Resolver::ValidateMatrixConstructor(
const ast::TypeConstructorExpression* ctor,
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index cf26a39..047ccc1 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -280,7 +280,7 @@
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
- bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source);
+ bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type);
bool ValidateFunctionParameter(const ast::Function* func,
@@ -300,6 +300,7 @@
const std::string& type_name,
const sem::Type* rhs_type,
const std::string& rhs_type_name);
+ bool ValidateVector(const sem::Vector* ty, const Source& source);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 70db243..d231774 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1356,29 +1356,19 @@
// vec4<f32> foo
// }
// struct A {
- // vec3<struct b> mem
+ // array<b, 3> mem
// }
// var c : A
// c.mem[0].foo.yx
// -> vec2<f32>
//
- // MemberAccessor{
- // MemberAccessor{
- // ArrayAccessor{
- // MemberAccessor{
- // Identifier{c}
- // Identifier{mem}
- // }
- // ScalarConstructor{0}
- // }
- // Identifier{foo}
- // }
- // Identifier{yx}
+ // fn f() {
+ // c.mem[0].foo
// }
//
auto* stB = Structure("B", {Member("foo", ty.vec4<f32>())});
- auto* stA = Structure("A", {Member("mem", ty.vec(ty.Of(stB), 3))});
+ auto* stA = Structure("A", {Member("mem", ty.array(ty.Of(stB), 3))});
Global("c", ty.Of(stA), ast::StorageClass::kPrivate);
auto* mem = MemberAccessor(
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index 68c307d..a0a17c2 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -156,6 +156,9 @@
template <typename T>
using mat4x4 = mat<4, 4, T>;
+template <int N, typename T>
+struct array {};
+
template <typename TO, int ID = 0>
struct alias {};
@@ -384,6 +387,43 @@
}
};
+/// Helper for building array types and expressions
+template <int N, typename T>
+struct DataType<array<N, T>> {
+ /// true as arrays are a composite type
+ static constexpr bool is_composite = true;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST array type
+ static inline ast::Type* AST(ProgramBuilder& b) {
+ return b.ty.array(DataType<T>::AST(b), N);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic array type
+ static inline sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::Array>(DataType<T>::Sem(b), N);
+ }
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element in the array will be initialized
+ /// with
+ /// @return a new AST array value expression
+ static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+ return b.Construct(AST(b), ExprArgs(b, elem_value));
+ }
+
+ /// @param b the ProgramBuilder
+ /// @param elem_value the value each element will be initialized with
+ /// @return the list of expressions that are used to construct the array
+ static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
+ int elem_value) {
+ ast::ExpressionList args;
+ for (int i = 0; i < N; i++) {
+ args.emplace_back(DataType<T>::Expr(b, elem_value));
+ }
+ return args;
+ }
+};
+
/// Struct of all creation pointer types
struct CreatePtrs {
/// ast node type create function
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index fd73e05..52925b4 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -42,6 +42,8 @@
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
+template <int N, typename T>
+using array = builder::array<N, T>;
template <typename T>
using alias = builder::alias<T>;
template <typename T>
@@ -751,6 +753,126 @@
} // namespace StorageTextureTests
+namespace MatrixTests {
+struct Params {
+ uint32_t columns;
+ uint32_t rows;
+ builder::ast_type_func_ptr elem_ty;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t columns, uint32_t rows) {
+ return Params{columns, rows, DataType<T>::AST};
+}
+
+using ValidMatrixTypes = ResolverTestWithParam<Params>;
+TEST_P(ValidMatrixTypes, Okay) {
+ // var a : matNxM<EL_TY>;
+ auto& params = GetParam();
+ Global("a", ty.mat(params.elem_ty(*this), params.columns, params.rows),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ ValidMatrixTypes,
+ testing::Values(ParamsFor<f32>(2, 2),
+ ParamsFor<f32>(2, 3),
+ ParamsFor<f32>(2, 4),
+ ParamsFor<f32>(3, 2),
+ ParamsFor<f32>(3, 3),
+ ParamsFor<f32>(3, 4),
+ ParamsFor<f32>(4, 2),
+ ParamsFor<f32>(4, 3),
+ ParamsFor<f32>(4, 4),
+ ParamsFor<alias<f32>>(4, 2),
+ ParamsFor<alias<f32>>(4, 3),
+ ParamsFor<alias<f32>>(4, 4)));
+
+using InvalidMatrixElementTypes = ResolverTestWithParam<Params>;
+TEST_P(InvalidMatrixElementTypes, InvalidElementType) {
+ // var a : matNxM<EL_TY>;
+ auto& params = GetParam();
+ Global("a",
+ ty.mat(Source{{12, 34}}, params.elem_ty(*this), params.columns,
+ params.rows),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ InvalidMatrixElementTypes,
+ testing::Values(ParamsFor<bool>(4, 2),
+ ParamsFor<i32>(4, 3),
+ ParamsFor<u32>(4, 4),
+ ParamsFor<vec2<f32>>(2, 2),
+ ParamsFor<vec3<i32>>(2, 3),
+ ParamsFor<vec4<u32>>(2, 4),
+ ParamsFor<mat2x2<f32>>(3, 2),
+ ParamsFor<mat3x3<f32>>(3, 3),
+ ParamsFor<mat4x4<f32>>(3, 4),
+ ParamsFor<array<2, f32>>(4, 2)));
+} // namespace MatrixTests
+
+namespace VectorTests {
+struct Params {
+ uint32_t width;
+ builder::ast_type_func_ptr elem_ty;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t width) {
+ return Params{width, DataType<T>::AST};
+}
+
+using ValidVectorTypes = ResolverTestWithParam<Params>;
+TEST_P(ValidVectorTypes, Okay) {
+ // var a : vecN<EL_TY>;
+ auto& params = GetParam();
+ Global("a", ty.vec(params.elem_ty(*this), params.width),
+ ast::StorageClass::kPrivate);
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ ValidVectorTypes,
+ testing::Values(ParamsFor<bool>(2),
+ ParamsFor<f32>(2),
+ ParamsFor<i32>(2),
+ ParamsFor<u32>(2),
+ ParamsFor<bool>(3),
+ ParamsFor<f32>(3),
+ ParamsFor<i32>(3),
+ ParamsFor<u32>(3),
+ ParamsFor<bool>(4),
+ ParamsFor<f32>(4),
+ ParamsFor<i32>(4),
+ ParamsFor<u32>(4),
+ ParamsFor<alias<bool>>(4),
+ ParamsFor<alias<f32>>(4),
+ ParamsFor<alias<i32>>(4),
+ ParamsFor<alias<u32>>(4)));
+
+using InvalidVectorElementTypes = ResolverTestWithParam<Params>;
+TEST_P(InvalidVectorElementTypes, InvalidElementType) {
+ // var a : vecN<EL_TY>;
+ auto& params = GetParam();
+ Global("a", ty.vec(Source{{12, 34}}, params.elem_ty(*this), params.width),
+ ast::StorageClass::kPrivate);
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: vector element type must be 'bool', 'f32', 'i32' or 'u32'");
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+ InvalidVectorElementTypes,
+ testing::Values(ParamsFor<vec2<f32>>(2),
+ ParamsFor<vec3<i32>>(2),
+ ParamsFor<vec4<u32>>(2),
+ ParamsFor<mat2x2<f32>>(2),
+ ParamsFor<mat3x3<f32>>(2),
+ ParamsFor<mat4x4<f32>>(2),
+ ParamsFor<array<2, f32>>(2)));
+} // namespace VectorTests
+
} // namespace
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index c688ad7..fee3f57 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -2007,17 +2007,6 @@
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
-TEST_F(MatrixConstructorTest, Expr_Constructor_Matrix_NotF32) {
- // m2x2<i32>()
- SetSource(Source::Location({12, 34}));
- auto* tc = mat2x2<i32>(
- create<ast::TypeConstructorExpression>(ty.mat2x2<i32>(), ExprList()));
- WrapInFunction(tc);
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
-}
-
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
// matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments