Resolver: Enforce matrix constructor type rules
Added enforcement for matrix constructor type rules according to the
table in https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
Fixed: tint:633
Change-Id: I97fc7f558f04780ed03252d94c071af3e0e07e26
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45020
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Arman Uguray <armansito@chromium.org>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 6e16b2e..d0e24d9 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -63,6 +63,13 @@
T old_value_;
};
+// Helper function that returns the range union of two source locations. The
+// `start` and `end` locations are assumed to refer to the same source file.
+Source CombineSourceRange(const Source& start, const Source& end) {
+ return Source(Source::Range(start.range.begin, end.range.end),
+ start.file_path, start.file_content);
+}
+
} // namespace
Resolver::Resolver(ProgramBuilder* builder)
@@ -572,9 +579,11 @@
// obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
if (auto* vec_type = type_ctor->type()->As<type::Vector>()) {
- return VectorConstructor(*vec_type, type_ctor->values());
+ return VectorConstructor(vec_type, type_ctor->values());
}
- // TODO(crbug.com/tint/633): Validate matrix constructor
+ if (auto* mat_type = type_ctor->type()->As<type::Matrix>()) {
+ return MatrixConstructor(mat_type, type_ctor->values());
+ }
// TODO(crbug.com/tint/634): Validate array constructor
} else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
SetType(expr, scalar_ctor->literal()->type());
@@ -584,9 +593,9 @@
return true;
}
-bool Resolver::VectorConstructor(const type::Vector& vec_type,
+bool Resolver::VectorConstructor(const type::Vector* vec_type,
const ast::ExpressionList& values) {
- type::Type* elem_type = vec_type.type()->UnwrapAll();
+ type::Type* elem_type = vec_type->type()->UnwrapAll();
size_t value_cardinality_sum = 0;
for (auto* value : values) {
type::Type* value_type = TypeOf(value)->UnwrapAll();
@@ -635,26 +644,63 @@
// A correct vector constructor must either be a zero-value expression
// or the number of components of all constructor arguments must add up
// to the vector cardinality.
- if (value_cardinality_sum > 0 && value_cardinality_sum != vec_type.size()) {
+ if (value_cardinality_sum > 0 && value_cardinality_sum != vec_type->size()) {
if (values.empty()) {
TINT_ICE(diagnostics_)
<< "constructor arguments expected to be non-empty!";
}
const Source& values_start = values[0]->source();
const Source& values_end = values[values.size() - 1]->source();
- const Source src(
- Source::Range(values_start.range.begin, values_end.range.end),
- values_start.file_path, values_start.file_content);
diagnostics_.add_error(
"attempted to construct '" +
- vec_type.FriendlyName(builder_->Symbols()) + "' with " +
+ vec_type->FriendlyName(builder_->Symbols()) + "' with " +
std::to_string(value_cardinality_sum) + " component(s)",
- src);
+ CombineSourceRange(values_start, values_end));
return false;
}
return true;
}
+bool Resolver::MatrixConstructor(const type::Matrix* matrix_type,
+ const ast::ExpressionList& values) {
+ // Zero Value expression
+ if (values.empty()) {
+ return true;
+ }
+
+ type::Type* elem_type = matrix_type->type()->UnwrapAll();
+ if (matrix_type->columns() != values.size()) {
+ const Source& values_start = values[0]->source();
+ const Source& values_end = values[values.size() - 1]->source();
+ diagnostics_.add_error(
+ "expected " + std::to_string(matrix_type->columns()) + " '" +
+ VectorPretty(matrix_type->rows(), elem_type) + "' arguments in '" +
+ matrix_type->FriendlyName(builder_->Symbols()) +
+ "' constructor, found " + std::to_string(values.size()),
+ CombineSourceRange(values_start, values_end));
+ return false;
+ }
+
+ for (auto* value : values) {
+ type::Type* value_type = TypeOf(value)->UnwrapAll();
+ auto* value_vec = value_type->As<type::Vector>();
+
+ if (!value_vec || value_vec->size() != matrix_type->rows() ||
+ elem_type != value_vec->type()->UnwrapAll()) {
+ diagnostics_.add_error(
+ "expected argument type '" +
+ VectorPretty(matrix_type->rows(), elem_type) + "' in '" +
+ matrix_type->FriendlyName(builder_->Symbols()) +
+ "' constructor, found '" +
+ value_type->FriendlyName(builder_->Symbols()) + "'",
+ value->source());
+ return false;
+ }
+ }
+
+ return true;
+}
+
bool Resolver::Identifier(ast::IdentifierExpression* expr) {
auto symbol = expr->symbol();
VariableInfo* var;
@@ -1501,6 +1547,11 @@
return callback();
}
+std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) {
+ type::Vector vec_type(element_type, size);
+ return vec_type.FriendlyName(builder_->Symbols());
+}
+
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
: declaration(decl), storage_class(decl->declared_storage_class()) {}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 7bb56b6..63e5850 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -194,7 +194,9 @@
bool Call(ast::CallExpression*);
bool CaseStatement(ast::CaseStatement*);
bool Constructor(ast::ConstructorExpression*);
- bool VectorConstructor(const type::Vector& vec_type,
+ bool VectorConstructor(const type::Vector* vec_type,
+ const ast::ExpressionList& values);
+ bool MatrixConstructor(const type::Matrix* matrix_type,
const ast::ExpressionList& values);
bool Expression(ast::Expression*);
bool Expressions(const ast::ExpressionList&);
@@ -247,6 +249,13 @@
template <typename F>
bool BlockScope(BlockInfo::Type type, F&& callback);
+ /// Returns a human-readable string representation of the vector type name
+ /// with the given parameters.
+ /// @param size the vector dimension
+ /// @param element_type scalar vector sub-element type
+ /// @return pretty string representation
+ std::string VectorPretty(uint32_t size, type::Type* element_type);
+
ProgramBuilder* const builder_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
diag::List diagnostics_;
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index 9dc2e0d..8cc4093 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -1664,6 +1664,357 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+struct MatrixDimensions {
+ uint32_t rows;
+ uint32_t columns;
+};
+
+std::string MatrixStr(const MatrixDimensions& dimensions,
+ std::string subtype = "f32") {
+ return "mat" + std::to_string(dimensions.columns) + "x" +
+ std::to_string(dimensions.rows) + "<" + subtype + ">";
+}
+
+std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
+ return "vec" + std::to_string(dimensions) + "<" + subtype + ">";
+}
+
+using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
+ // matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
+
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns - 1; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:1 error: expected " + std::to_string(param.columns) + " '" +
+ VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
+ "' constructor, found " + std::to_string(param.columns - 1));
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
+ // matNxM<f32>(vecM<f32>(), ...); with N + 1 arguments
+
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns + 1; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:1 error: expected " + std::to_string(param.columns) + " '" +
+ VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
+ "' constructor, found " + std::to_string(param.columns + 1));
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
+ // matNxM<f32>(1.0, 1.0, ...); N arguments
+
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::ScalarConstructorExpression>(Source{{12, i}},
+ Literal(1.0f)));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
+ VecStr(param.rows) + "' in '" + MatrixStr(param) +
+ "' constructor, found 'f32'");
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_Constructor_Error_TooFewRowsInVectorArgument) {
+ // matNxM<f32>(vecM<f32>(),...,vecM-1<f32>());
+
+ const auto param = GetParam();
+
+ // Skip the test if parameters would have resuled in an invalid vec1 type.
+ if (param.rows == 2) {
+ return;
+ }
+
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* valid_vec_type = create<type::Vector>(ty.f32(), param.rows);
+ auto* invalid_vec_type = create<type::Vector>(ty.f32(), param.rows - 1);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns - 1; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, valid_vec_type, ExprList()));
+ }
+ const size_t kInvalidLoc = 2 * (param.columns - 1);
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, kInvalidLoc}}, invalid_vec_type, ExprList()));
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
+ " error: expected argument type '" +
+ VecStr(param.rows) + "' in '" + MatrixStr(param) +
+ "' constructor, found '" +
+ VecStr(param.rows - 1) + "'");
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_Constructor_Error_TooManyRowsInVectorArgument) {
+ // matNxM<f32>(vecM<f32>(),...,vecM+1<f32>());
+
+ const auto param = GetParam();
+
+ // Skip the test if parameters would have resuled in an invalid vec5 type.
+ if (param.rows == 4) {
+ return;
+ }
+
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* valid_vec_type = create<type::Vector>(ty.f32(), param.rows);
+ auto* invalid_vec_type = create<type::Vector>(ty.f32(), param.rows + 1);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns - 1; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, valid_vec_type, ExprList()));
+ }
+ const size_t kInvalidLoc = 2 * (param.columns - 1);
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, kInvalidLoc}}, invalid_vec_type, ExprList()));
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
+ " error: expected argument type '" +
+ VecStr(param.rows) + "' in '" + MatrixStr(param) +
+ "' constructor, found '" +
+ VecStr(param.rows + 1) + "'");
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_Constructor_Error_ArgumentVectorElementTypeMismatch) {
+ // matNxM<f32>(vecM<u32>(), ...); with N arguments
+
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.u32(), param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
+ VecStr(param.rows) + "' in '" + MatrixStr(param) +
+ "' constructor, found '" +
+ VecStr(param.rows, "u32") + "'");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
+ // matNxM<f32>();
+
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 40}},
+ matrix_type, ExprList());
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
+ // matNxM<f32>(vecM<f32>(), ...); with N arguments
+
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
+ // matNxM<Float32>(vecM<u32>(), ...); with N arguments
+
+ const auto param = GetParam();
+ auto* f32_alias = ty.alias("Float32", ty.f32());
+ auto* matrix_type =
+ create<type::Matrix>(f32_alias, param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.u32(), param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:1 error: expected argument type '" + VecStr(param.rows) +
+ "' in '" + MatrixStr(param, "Float32") +
+ "' constructor, found '" + VecStr(param.rows, "u32") + "'");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
+ // matNxM<Float32>(vecM<f32>(), ...); with N arguments
+
+ const auto param = GetParam();
+ auto* f32_alias = ty.alias("Float32", ty.f32());
+ auto* matrix_type =
+ create<type::Matrix>(f32_alias, param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverValidationTest, Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
+ auto* vec2_alias = ty.alias("VectorUnsigned2", ty.vec2<u32>());
+ auto* tc = mat2x2<f32>(create<ast::TypeConstructorExpression>(
+ Source{{12, 34}}, vec2_alias, ExprList()),
+ vec2<f32>());
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: expected argument type 'vec2<f32>' in 'mat2x2<f32>' "
+ "constructor, found 'vec2<u32>'");
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
+ auto* vec_alias = ty.alias("VectorFloat2", vec_type);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_alias, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* f32_alias = ty.alias("UnsignedInt", ty.u32());
+ auto* vec_type = create<type::Vector>(f32_alias, param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
+ VecStr(param.rows) + "' in '" + MatrixStr(param) +
+ "' constructor, found '" +
+ VecStr(param.rows, "UnsignedInt") + "'");
+}
+
+TEST_P(MatrixConstructorTest,
+ Expr_Constructor_ArgumentElementTypeAlias_Success) {
+ const auto param = GetParam();
+ auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
+ auto* f32_alias = ty.alias("Float32", ty.f32());
+ auto* vec_type = create<type::Vector>(f32_alias, param.rows);
+
+ ast::ExpressionList args;
+ for (uint32_t i = 1; i <= param.columns; i++) {
+ args.push_back(create<ast::TypeConstructorExpression>(
+ Source{{12, i}}, vec_type, ExprList()));
+ }
+
+ auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
+ std::move(args));
+ WrapInFunction(tc);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverValidationTest,
+ MatrixConstructorTest,
+ testing::Values(MatrixDimensions{2, 2},
+ MatrixDimensions{3, 2},
+ MatrixDimensions{4, 2},
+ MatrixDimensions{2, 3},
+ MatrixDimensions{3, 3},
+ MatrixDimensions{4, 3},
+ MatrixDimensions{2, 4},
+ MatrixDimensions{3, 4},
+ MatrixDimensions{4, 4}));
+
} // namespace
} // namespace resolver
} // namespace tint