[ir] Validate matrix constructor arguments
Use the intrinsics table to check that all matrix construct
instructions have valid type signatures.
Fix one invalid test and a bug in a SPIR-V reader transform.
Fixed: 427965903, 433565172
Change-Id: I8c21eee81022e8652bc99742de18a057a973feb9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/255014
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index da472ad..3934eb3 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -3153,8 +3153,16 @@
// TODO(crbug.com/427964608): This needs special handling as Element() produces nullptr.
} else if (result_type->Is<core::type::Vector>()) {
// TODO(crbug.com/427964205): This needs special handling as there are many cases.
- } else if (result_type->Is<core::type::Matrix>()) {
- // TODO(crbug.com/427965903): This needs special handling as there are many cases.
+ } else if (auto* mat = result_type->As<core::type::Matrix>()) {
+ auto table = intrinsic::Table<intrinsic::Dialect>(type_mgr_, symbols_);
+ auto ctor_conv = intrinsic::MatrixCtorConv(mat->Columns(), mat->Rows());
+ auto arg_types = Transform<8>(args, [&](auto* v) { return v->Type(); });
+ auto match = table.Lookup(ctor_conv, Vector{mat->Type()}, std::move(arg_types),
+ core::EvaluationStage::kConstant);
+ if (match != Success) {
+ AddError(construct) << "no matching overload for " << mat->FriendlyName()
+ << " constructor";
+ }
} else if (result_type->Is<core::type::Array>()) {
check_args_match_elements();
} else if (auto* str = As<core::type::Struct>(result_type)) {
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index e4a38aa..dc06f9a 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -210,6 +210,69 @@
)")) << res.Failure();
}
+TEST_F(IR_ValidatorTest, Construct_Matrix_NoArgs) {
+ auto* f = b.Function("f", ty.void_());
+ b.Append(f->Block(), [&] {
+ b.Construct(ty.mat2x2<f32>());
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_Scalar) {
+ auto* f = b.Function("f", ty.void_());
+ b.Append(f->Block(), [&] {
+ b.Construct(ty.mat2x2<f32>(), 1_f, 2_f, 3_f, 4_f);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_ColumnVectors) {
+ auto* f = b.Function("f", ty.void_());
+ b.Append(f->Block(), [&] {
+ auto* v1 = b.Composite(ty.vec2<f32>(), 1_f, 2_f);
+ auto* v2 = b.Composite(ty.vec2<f32>(), 3_f, 4_f);
+ b.Construct(ty.mat2x2<f32>(), v1, v2);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_EQ(res, Success) << res.Failure();
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_MixedScalarVector) {
+ auto* f = b.Function("f", ty.void_());
+ b.Append(f->Block(), [&] {
+ b.Construct(ty.mat2x2<f32>(), 1_f, b.Composite(ty.vec2<f32>(), 2_f, 3_f));
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(
+ res.Failure().reason,
+ testing::HasSubstr("error: construct: no matching overload for mat2x2<f32> constructor"));
+}
+
+TEST_F(IR_ValidatorTest, Construct_Matrix_Scalar_WrongType) {
+ auto* f = b.Function("f", ty.void_());
+ b.Append(f->Block(), [&] {
+ b.Construct(ty.mat2x2<f32>(), 1_f, 2_f, 3_h, 4_f);
+ b.Return(f);
+ });
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_THAT(
+ res.Failure().reason,
+ testing::HasSubstr("error: construct: no matching overload for mat2x2<f32> constructor"));
+}
+
TEST_F(IR_ValidatorTest, Construct_Struct_ZeroValue) {
auto* str_ty = ty.Struct(mod.symbols.New("MyStruct"), {
{mod.symbols.New("a"), ty.i32()},
diff --git a/src/tint/lang/hlsl/writer/construct_test.cc b/src/tint/lang/hlsl/writer/construct_test.cc
index 1c6e11c..55323c2 100644
--- a/src/tint/lang/hlsl/writer/construct_test.cc
+++ b/src/tint/lang/hlsl/writer/construct_test.cc
@@ -141,7 +141,7 @@
TEST_F(HlslWriterTest, ConstructMatrix) {
auto* f = b.Function("a", ty.mat2x2<f32>());
b.Append(f->Block(), [&] {
- auto* v = b.Var("v", 2_f);
+ auto* v = b.Let("v", 2_f);
b.Return(f, b.Construct(ty.mat2x2<f32>(), v, v, v, v));
});
@@ -149,7 +149,8 @@
EXPECT_EQ(output_.hlsl, R"(
float2x2 a() {
float v = 2.0f;
- return float2x2(v, v, v, v);
+ float2 v_1 = float2(v, v);
+ return float2x2(v_1, float2(v, v));
}
[numthreads(1, 1, 1)]
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 2b7046d..9f1caa4 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -1287,10 +1287,10 @@
// a * fkgj - b * ekgi + c * ejfi
auto* r_33 = sub_add_mul3(ma, fkgj, mb, ekgi, mc, ejfi);
- auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02, r_03);
- auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12, r_13);
- auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22, r_23);
- auto* r4 = b.Construct(ty.vec3(elem_ty), r_30, r_31, r_32, r_33);
+ auto* r1 = b.Construct(ty.vec4(elem_ty), r_00, r_01, r_02, r_03);
+ auto* r2 = b.Construct(ty.vec4(elem_ty), r_10, r_11, r_12, r_13);
+ auto* r3 = b.Construct(ty.vec4(elem_ty), r_20, r_21, r_22, r_23);
+ auto* r4 = b.Construct(ty.vec4(elem_ty), r_30, r_31, r_32, r_33);
auto* m = b.Construct(mat_ty, r1, r2, r3, r4);
auto* inv = b.Multiply(mat_ty, inv_det, m);
diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc
index 684703e..26e9340 100644
--- a/src/tint/lang/spirv/reader/lower/builtins_test.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc
@@ -532,10 +532,10 @@
%155:f32 = mul %7, %74
%156:f32 = sub %153, %154
%157:f32 = add %156, %155
- %158:vec3<f32> = construct %80, %85, %90, %95
- %159:vec3<f32> = construct %102, %107, %112, %117
- %160:vec3<f32> = construct %122, %127, %132, %137
- %161:vec3<f32> = construct %142, %147, %152, %157
+ %158:vec4<f32> = construct %80, %85, %90, %95
+ %159:vec4<f32> = construct %102, %107, %112, %117
+ %160:vec4<f32> = construct %122, %127, %132, %137
+ %161:vec4<f32> = construct %142, %147, %152, %157
%162:mat4x4<f32> = construct %158, %159, %160, %161
%163:mat4x4<f32> = mul %4, %162
ret
@@ -728,10 +728,10 @@
%155:f16 = mul %7, %74
%156:f16 = sub %153, %154
%157:f16 = add %156, %155
- %158:vec3<f16> = construct %80, %85, %90, %95
- %159:vec3<f16> = construct %102, %107, %112, %117
- %160:vec3<f16> = construct %122, %127, %132, %137
- %161:vec3<f16> = construct %142, %147, %152, %157
+ %158:vec4<f16> = construct %80, %85, %90, %95
+ %159:vec4<f16> = construct %102, %107, %112, %117
+ %160:vec4<f16> = construct %122, %127, %132, %137
+ %161:vec4<f16> = construct %142, %147, %152, %157
%162:mat4x4<f16> = construct %158, %159, %160, %161
%163:mat4x4<f16> = mul %4, %162
ret