Validate multiply of invalid vector/matrix sizes
Added tests that test all combos of vec*mat, mat*vec, and mat*mat for 2,
3, and 4 dimensions.
Bug: tint:698
Change-Id: I4a407228261cf8ea2a93bc7077544e5a9244d854
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46660
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/program_builder.h b/src/program_builder.h
index 76c823d..613fb62 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -928,7 +928,7 @@
/// @param rhs the right hand argument to the addition operation
/// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs`
template <typename LHS, typename RHS>
- ast::Expression* Add(LHS&& lhs, RHS&& rhs) {
+ ast::BinaryExpression* Add(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kAdd,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
@@ -938,7 +938,7 @@
/// @param rhs the right hand argument to the subtraction operation
/// @returns a `ast::BinaryExpression` subtracting `rhs` from `lhs`
template <typename LHS, typename RHS>
- ast::Expression* Sub(LHS&& lhs, RHS&& rhs) {
+ ast::BinaryExpression* Sub(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kSubtract,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
@@ -948,17 +948,28 @@
/// @param rhs the right hand argument to the multiplication operation
/// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
template <typename LHS, typename RHS>
- ast::Expression* Mul(LHS&& lhs, RHS&& rhs) {
+ ast::BinaryExpression* Mul(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
+ /// @param source the source information
+ /// @param lhs the left hand argument to the multiplication operation
+ /// @param rhs the right hand argument to the multiplication operation
+ /// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
+ template <typename LHS, typename RHS>
+ ast::BinaryExpression* Mul(const Source& source, LHS&& lhs, RHS&& rhs) {
+ return create<ast::BinaryExpression>(source, ast::BinaryOp::kMultiply,
+ Expr(std::forward<LHS>(lhs)),
+ Expr(std::forward<RHS>(rhs)));
+ }
+
/// @param arr the array argument for the array accessor expression
/// @param idx the index argument for the array accessor expression
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
template <typename ARR, typename IDX>
- ast::Expression* IndexAccessor(ARR&& arr, IDX&& idx) {
+ ast::ArrayAccessorExpression* IndexAccessor(ARR&& arr, IDX&& idx) {
return create<ast::ArrayAccessorExpression>(Expr(std::forward<ARR>(arr)),
Expr(std::forward<IDX>(idx)));
}
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 1158ad6..0c79dd7 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1040,8 +1040,10 @@
auto* rhs_vec_elem_type =
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
- const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
- (lhs_vec_elem_type == rhs_vec_elem_type);
+ const bool matching_vec_elem_types =
+ lhs_vec_elem_type && rhs_vec_elem_type &&
+ (lhs_vec_elem_type == rhs_vec_elem_type) &&
+ (lhs_vec->size() == rhs_vec->size());
const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
@@ -1106,19 +1108,22 @@
// Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
- rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
+ rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
+ (lhs_vec->size() == rhs_mat->rows())) {
return true;
}
// Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
- rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) {
+ rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
+ (lhs_mat->columns() == rhs_vec->size())) {
return true;
}
// Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
- rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
+ rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
+ (lhs_mat->columns() == rhs_mat->rows())) {
return true;
}
}
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 12d019e..7ee58c6 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1263,11 +1263,11 @@
Global("lhs", lhs_type, ast::StorageClass::kNone);
Global("rhs", rhs_type, ast::StorageClass::kNone);
- auto* expr = create<ast::BinaryExpression>(
- Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs"));
+ auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, params.op,
+ Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
- ASSERT_FALSE(r()->Resolve()) << r()->error();
+ ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
@@ -1280,6 +1280,99 @@
Expr_Binary_Test_Invalid,
testing::Combine(testing::ValuesIn(all_valid_cases),
testing::ValuesIn(all_create_type_funcs)));
+
+using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
+ ResolverTestWithParam<std::tuple<bool, uint32_t, uint32_t, uint32_t>>;
+TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
+ bool vec_by_mat = std::get<0>(GetParam());
+ uint32_t vec_size = std::get<1>(GetParam());
+ uint32_t mat_rows = std::get<2>(GetParam());
+ uint32_t mat_cols = std::get<3>(GetParam());
+
+ type::Type* lhs_type;
+ type::Type* rhs_type;
+ type::Type* result_type;
+ bool is_valid_expr;
+
+ if (vec_by_mat) {
+ lhs_type = create<type::Vector>(ty.f32(), vec_size);
+ rhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
+ result_type = create<type::Vector>(ty.f32(), mat_cols);
+ is_valid_expr = vec_size == mat_rows;
+ } else {
+ lhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
+ rhs_type = create<type::Vector>(ty.f32(), vec_size);
+ result_type = create<type::Vector>(ty.f32(), mat_rows);
+ is_valid_expr = vec_size == mat_cols;
+ }
+
+ Global("lhs", lhs_type, ast::StorageClass::kNone);
+ Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+ auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
+ WrapInFunction(expr);
+
+ if (is_valid_expr) {
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(expr) == result_type);
+ } else {
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: Binary expression operand types are invalid for "
+ "this operation: " +
+ lhs_type->FriendlyName(Symbols()) + " " +
+ FriendlyName(expr->op()) + " " +
+ rhs_type->FriendlyName(Symbols()));
+ }
+}
+auto all_dimension_values = testing::Values(2u, 3u, 4u);
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ Expr_Binary_Test_Invalid_VectorMatrixMultiply,
+ testing::Combine(testing::Values(true, false),
+ all_dimension_values,
+ all_dimension_values,
+ all_dimension_values));
+
+using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
+ ResolverTestWithParam<std::tuple<uint32_t, uint32_t, uint32_t, uint32_t>>;
+TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
+ uint32_t lhs_mat_rows = std::get<0>(GetParam());
+ uint32_t lhs_mat_cols = std::get<1>(GetParam());
+ uint32_t rhs_mat_rows = std::get<2>(GetParam());
+ uint32_t rhs_mat_cols = std::get<3>(GetParam());
+
+ auto* lhs_type = create<type::Matrix>(ty.f32(), lhs_mat_rows, lhs_mat_cols);
+ auto* rhs_type = create<type::Matrix>(ty.f32(), rhs_mat_rows, rhs_mat_cols);
+ auto* result_type =
+ create<type::Matrix>(ty.f32(), lhs_mat_rows, rhs_mat_cols);
+
+ Global("lhs", lhs_type, ast::StorageClass::kNone);
+ Global("rhs", rhs_type, ast::StorageClass::kNone);
+
+ auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
+ WrapInFunction(expr);
+
+ bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
+ if (is_valid_expr) {
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(TypeOf(expr) == result_type);
+ } else {
+ ASSERT_FALSE(r()->Resolve());
+ ASSERT_EQ(r()->error(),
+ "12:34 error: Binary expression operand types are invalid for "
+ "this operation: " +
+ lhs_type->FriendlyName(Symbols()) + " " +
+ FriendlyName(expr->op()) + " " +
+ rhs_type->FriendlyName(Symbols()));
+ }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTest,
+ Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
+ testing::Combine(all_dimension_values,
+ all_dimension_values,
+ all_dimension_values,
+ all_dimension_values));
+
} // namespace ExprBinaryTest
using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;