[spirv-reader] Always apply @stride to matrices
When we support row-major matrices, we will not know the natural
matrix stride when processing the MatrixStride decoration. Change the
parser to always apply the @stride() attribute, and move the logic for
dropping it into the DecomposeMatrixStride transform.
Bug: 364267168
Change-Id: Icbfb85aa9d9ee01f4e7813fdd6888b70721de915
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/207417
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
index 9f98364..a50c38c 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
@@ -89,6 +89,7 @@
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.
+ bool made_changes = false;
Hashmap<const core::type::StructMember*, MatrixInfo, 8> decomposed;
for (auto* node : src.ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
@@ -98,19 +99,41 @@
continue;
}
for (auto* member : str_ty->Members()) {
- auto* matrix = member->Type()->As<core::type::Matrix>();
- if (!matrix) {
- continue;
- }
auto* attr =
ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
if (!attr) {
+ // No stride attribute - nothing to do.
continue;
}
+
+ // Get the matrix type, which may be nested inside an array.
+ auto* ty = member->Type();
+ while (auto* arr = ty->As<core::type::Array>()) {
+ ty = arr->ElemType();
+ }
+ auto* matrix = ty->As<core::type::Matrix>();
+ TINT_ASSERT(matrix);
+
+ made_changes = true;
+
uint32_t stride = attr->stride;
if (matrix->ColumnStride() == stride) {
+ // The attribute specifies the natural stride, so just remove the attribute.
+ auto* disable_validation = ast::GetAttribute<ast::DisableValidationAttribute>(
+ member->Declaration()->attributes);
+ TINT_ASSERT(disable_validation->validation ==
+ ast::DisabledValidation::kIgnoreStrideAttribute);
+ ctx.Remove(member->Declaration()->attributes, attr);
+ ctx.Remove(member->Declaration()->attributes, disable_validation);
continue;
}
+
+ if (member->Type()->Is<core::type::Array>()) {
+ b.Diagnostics().AddError(attr->source)
+ << "custom matrix strides not currently supported on array of matrices";
+ return Program(std::move(b));
+ }
+
// We've got ourselves a struct member of a matrix type with a custom
// stride. Replace this with an array of column vectors.
MatrixInfo info{stride, matrix};
@@ -122,7 +145,7 @@
}
}
- if (decomposed.IsEmpty()) {
+ if (!made_changes) {
return SkipTransform;
}
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
index 85e9094..a69ff9d 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
@@ -220,7 +220,6 @@
@size(16)
padding_0 : u32,
/* @offset(16u) */
- @stride(8) @internal(disable_validation__ignore_stride)
m : mat2x2<f32>,
}
@@ -237,6 +236,100 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(DecomposeStridedMatrixTest, ReadUniformArrayOfMatrix_DefaultStride) {
+ // struct S {
+ // @offset(16) @stride(8)
+ // @internal(ignore_stride_attribute)
+ // a : array<array<mat2x2<f32>, 4>, 4>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+ Vector{
+ b.MemberOffset(16_u),
+ b.create<ast::StrideAttribute>(8u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+ b.MemberAccessor("s", "m"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : array<array<mat2x2<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : array<array<mat2x2<f32>, 4>, 4> = s.m;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformArrayOfMatrix_CustomStride) {
+ // struct S {
+ // @offset(16) @stride(16)
+ // @internal(ignore_stride_attribute)
+ // a : array<array<mat2x2<f32>, 4>, 4>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+ Vector{
+ b.MemberOffset(16_u),
+ b.create<ast::StrideAttribute>(16u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+ b.MemberAccessor("s", "m"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(error: custom matrix strides not currently supported on array of matrices)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
// struct S {
// @offset(8) @stride(32)
diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
index 8e486f5..b79fb8c 100644
--- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
@@ -517,7 +517,6 @@
<< ShowType(struct_type_id);
break;
}
- uint32_t stride = decoration[1];
auto* ty = member_ty->UnwrapAlias();
while (auto* arr = ty->As<Array>()) {
ty = arr->type->UnwrapAlias();
@@ -527,14 +526,12 @@
Fail() << "MatrixStride cannot be applied to type " << ty->String();
break;
}
- uint32_t natural_stride = (mat->rows == 2) ? 8 : 16;
- if (stride == natural_stride) {
- break; // Decoration matches the natural stride for the matrix
- }
- if (!member_ty->Is<Matrix>()) {
- Fail() << "custom matrix strides not currently supported on array of matrices";
- break;
- }
+
+ // Note: We do not know at this point whether the matrix is laid out as row-major or
+ // column-major, and therefore do not know the "natural" stride. So we add the stride
+ // attribute unconditionally, and let the DecomposeStridedMatrix transform determine if
+ // anything needs to be done.
+
out.Add(create<ast::StrideAttribute>(Source{}, decoration[1]));
out.Add(builder_.ASTNodes().Create<ast::DisableValidationAttribute>(
builder_.ID(), builder_.AllocateNodeID(),
diff --git a/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc b/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc
index 104f5d7..11e78ea 100644
--- a/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc
@@ -80,7 +80,11 @@
ast_parser::Matrix matrix(&f32, 2, 2);
auto result =
p->ConvertMemberDecoration(1, 1, &matrix, {uint32_t(spv::Decoration::MatrixStride), 8});
- EXPECT_TRUE(result.list.IsEmpty());
+ ASSERT_FALSE(result.list.IsEmpty());
+ EXPECT_TRUE(result.list[0]->Is<ast::StrideAttribute>());
+ auto* stride_deco = result.list[0]->As<ast::StrideAttribute>();
+ ASSERT_NE(stride_deco, nullptr);
+ EXPECT_EQ(stride_deco->stride, 8u);
EXPECT_TRUE(p->error().empty());
}
@@ -106,7 +110,11 @@
ast_parser::Matrix matrix(&f32, 2, 4);
auto result =
p->ConvertMemberDecoration(1, 1, &matrix, {uint32_t(spv::Decoration::MatrixStride), 16});
- EXPECT_TRUE(result.list.IsEmpty());
+ ASSERT_FALSE(result.list.IsEmpty());
+ EXPECT_TRUE(result.list[0]->Is<ast::StrideAttribute>());
+ auto* stride_deco = result.list[0]->As<ast::StrideAttribute>();
+ ASSERT_NE(stride_deco, nullptr);
+ EXPECT_EQ(stride_deco->stride, 16u);
EXPECT_TRUE(p->error().empty());
}
diff --git a/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc b/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc
index e3f24e1..f92a464 100644
--- a/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc
@@ -1301,6 +1301,7 @@
const auto module_str = test::ToString(p->program());
EXPECT_THAT(module_str, HasSubstr(R"(struct S {
/* @offset(0) */
+ @stride(8) @internal(disable_validation__ignore_stride)
field0 : mat3x2f,
}
@@ -1308,7 +1309,7 @@
)")) << module_str;
}
-TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
+TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_ColMajor) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
OpDecorate %myvar DescriptorSet 0
@@ -1332,6 +1333,7 @@
const auto module_str = test::ToString(p->program());
EXPECT_THAT(module_str, HasSubstr(R"(struct S {
/* @offset(0) */
+ @stride(8) @internal(disable_validation__ignore_stride)
field0 : mat3x2f,
}