[spirv-reader] Always apply @stride to matrices

When we support row-major matrices, we will not know the natural
matrix stride when processing the MatrixStride decoration. Change the
parser to always apply the @stride() attribute, and move the logic for
dropping it into the DecomposeMatrixStride transform.

Bug: 364267168
Change-Id: Icbfb85aa9d9ee01f4e7813fdd6888b70721de915
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/207417
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
index 9f98364..a50c38c 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
@@ -89,6 +89,7 @@
     // Scan the program for all storage and uniform structure matrix members with
     // a custom stride attribute. Replace these matrices with an equivalent array,
     // and populate the `decomposed` map with the members that have been replaced.
+    bool made_changes = false;
     Hashmap<const core::type::StructMember*, MatrixInfo, 8> decomposed;
     for (auto* node : src.ASTNodes().Objects()) {
         if (auto* str = node->As<ast::Struct>()) {
@@ -98,19 +99,41 @@
                 continue;
             }
             for (auto* member : str_ty->Members()) {
-                auto* matrix = member->Type()->As<core::type::Matrix>();
-                if (!matrix) {
-                    continue;
-                }
                 auto* attr =
                     ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
                 if (!attr) {
+                    // No stride attribute - nothing to do.
                     continue;
                 }
+
+                // Get the matrix type, which may be nested inside an array.
+                auto* ty = member->Type();
+                while (auto* arr = ty->As<core::type::Array>()) {
+                    ty = arr->ElemType();
+                }
+                auto* matrix = ty->As<core::type::Matrix>();
+                TINT_ASSERT(matrix);
+
+                made_changes = true;
+
                 uint32_t stride = attr->stride;
                 if (matrix->ColumnStride() == stride) {
+                    // The attribute specifies the natural stride, so just remove the attribute.
+                    auto* disable_validation = ast::GetAttribute<ast::DisableValidationAttribute>(
+                        member->Declaration()->attributes);
+                    TINT_ASSERT(disable_validation->validation ==
+                                ast::DisabledValidation::kIgnoreStrideAttribute);
+                    ctx.Remove(member->Declaration()->attributes, attr);
+                    ctx.Remove(member->Declaration()->attributes, disable_validation);
                     continue;
                 }
+
+                if (member->Type()->Is<core::type::Array>()) {
+                    b.Diagnostics().AddError(attr->source)
+                        << "custom matrix strides not currently supported on array of matrices";
+                    return Program(std::move(b));
+                }
+
                 // We've got ourselves a struct member of a matrix type with a custom
                 // stride. Replace this with an array of column vectors.
                 MatrixInfo info{stride, matrix};
@@ -122,7 +145,7 @@
         }
     }
 
-    if (decomposed.IsEmpty()) {
+    if (!made_changes) {
         return SkipTransform;
     }
 
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
index 85e9094..a69ff9d 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
@@ -220,7 +220,6 @@
   @size(16)
   padding_0 : u32,
   /* @offset(16u) */
-  @stride(8) @internal(disable_validation__ignore_stride)
   m : mat2x2<f32>,
 }
 
@@ -237,6 +236,100 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(DecomposeStridedMatrixTest, ReadUniformArrayOfMatrix_DefaultStride) {
+    // struct S {
+    //   @offset(16) @stride(8)
+    //   @internal(ignore_stride_attribute)
+    //   a : array<array<mat2x2<f32>, 4>, 4>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : mat2x2<f32> = s.m;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+                          Vector{
+                              b.MemberOffset(16_u),
+                              b.create<ast::StrideAttribute>(8u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("x", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+                            b.MemberAccessor("s", "m"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : array<array<mat2x2<f32>, 4>, 4>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : array<array<mat2x2<f32>, 4>, 4> = s.m;
+}
+)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformArrayOfMatrix_CustomStride) {
+    // struct S {
+    //   @offset(16) @stride(16)
+    //   @internal(ignore_stride_attribute)
+    //   a : array<array<mat2x2<f32>, 4>, 4>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : mat2x2<f32> = s.m;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+                          Vector{
+                              b.MemberOffset(16_u),
+                              b.create<ast::StrideAttribute>(16u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("x", b.ty.array(b.ty.array(b.ty.mat2x2<f32>(), 4_a), 4_a),
+                            b.MemberAccessor("s", "m"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(error: custom matrix strides not currently supported on array of matrices)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
     // struct S {
     //   @offset(8) @stride(32)
diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
index 8e486f5..b79fb8c 100644
--- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
@@ -517,7 +517,6 @@
                        << ShowType(struct_type_id);
                 break;
             }
-            uint32_t stride = decoration[1];
             auto* ty = member_ty->UnwrapAlias();
             while (auto* arr = ty->As<Array>()) {
                 ty = arr->type->UnwrapAlias();
@@ -527,14 +526,12 @@
                 Fail() << "MatrixStride cannot be applied to type " << ty->String();
                 break;
             }
-            uint32_t natural_stride = (mat->rows == 2) ? 8 : 16;
-            if (stride == natural_stride) {
-                break;  // Decoration matches the natural stride for the matrix
-            }
-            if (!member_ty->Is<Matrix>()) {
-                Fail() << "custom matrix strides not currently supported on array of matrices";
-                break;
-            }
+
+            // Note: We do not know at this point whether the matrix is laid out as row-major or
+            // column-major, and therefore do not know the "natural" stride. So we add the stride
+            // attribute unconditionally, and let the DecomposeStridedMatrix transform determine if
+            // anything needs to be done.
+
             out.Add(create<ast::StrideAttribute>(Source{}, decoration[1]));
             out.Add(builder_.ASTNodes().Create<ast::DisableValidationAttribute>(
                 builder_.ID(), builder_.AllocateNodeID(),
diff --git a/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc b/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc
index 104f5d7..11e78ea 100644
--- a/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/convert_member_decoration_test.cc
@@ -80,7 +80,11 @@
     ast_parser::Matrix matrix(&f32, 2, 2);
     auto result =
         p->ConvertMemberDecoration(1, 1, &matrix, {uint32_t(spv::Decoration::MatrixStride), 8});
-    EXPECT_TRUE(result.list.IsEmpty());
+    ASSERT_FALSE(result.list.IsEmpty());
+    EXPECT_TRUE(result.list[0]->Is<ast::StrideAttribute>());
+    auto* stride_deco = result.list[0]->As<ast::StrideAttribute>();
+    ASSERT_NE(stride_deco, nullptr);
+    EXPECT_EQ(stride_deco->stride, 8u);
     EXPECT_TRUE(p->error().empty());
 }
 
@@ -106,7 +110,11 @@
     ast_parser::Matrix matrix(&f32, 2, 4);
     auto result =
         p->ConvertMemberDecoration(1, 1, &matrix, {uint32_t(spv::Decoration::MatrixStride), 16});
-    EXPECT_TRUE(result.list.IsEmpty());
+    ASSERT_FALSE(result.list.IsEmpty());
+    EXPECT_TRUE(result.list[0]->Is<ast::StrideAttribute>());
+    auto* stride_deco = result.list[0]->As<ast::StrideAttribute>();
+    ASSERT_NE(stride_deco, nullptr);
+    EXPECT_EQ(stride_deco->stride, 16u);
     EXPECT_TRUE(p->error().empty());
 }
 
diff --git a/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc b/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc
index e3f24e1..f92a464 100644
--- a/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/module_var_test.cc
@@ -1301,6 +1301,7 @@
     const auto module_str = test::ToString(p->program());
     EXPECT_THAT(module_str, HasSubstr(R"(struct S {
   /* @offset(0) */
+  @stride(8) @internal(disable_validation__ignore_stride)
   field0 : mat3x2f,
 }
 
@@ -1308,7 +1309,7 @@
 )")) << module_str;
 }
 
-TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
+TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_ColMajor) {
     auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
      OpName %myvar "myvar"
      OpDecorate %myvar DescriptorSet 0
@@ -1332,6 +1333,7 @@
     const auto module_str = test::ToString(p->program());
     EXPECT_THAT(module_str, HasSubstr(R"(struct S {
   /* @offset(0) */
+  @stride(8) @internal(disable_validation__ignore_stride)
   field0 : mat3x2f,
 }