[spirv-reader] Support strided f16 matrices

The DecomposeStridedMatrix transform was hardcoded to create f32
arrays, so just update it to use f16 if needed and add some tests.

Bug: 377728743
Change-Id: I882047bcb7ee855053718076a28661356201e9ca
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/215614
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
index a50c38c..8509871 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
@@ -57,7 +57,15 @@
 
     /// @returns the identifier of an array that holds an vector column for each row of the matrix.
     ast::Type array(ast::Builder* b) const {
-        return b->ty.array(b->ty.vec<f32>(matrix->Rows()), u32(matrix->Columns()),
+        ast::Type col_type;
+        if (matrix->Type()->Is<core::type::F32>()) {
+            col_type = b->ty.vec<f32>(matrix->Rows());
+        } else if (matrix->Type()->Is<core::type::F16>()) {
+            col_type = b->ty.vec<f16>(matrix->Rows());
+        } else {
+            TINT_UNREACHABLE();
+        }
+        return b->ty.array(col_type, u32(matrix->Columns()),
                            Vector{
                                b->Stride(stride),
                            });
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
index a69ff9d..e0dc2c8 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
@@ -637,6 +637,305 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_F16) {
+    // enable f16;
+    // struct S {
+    //   @offset(16) @stride(32)
+    //   @internal(ignore_stride_attribute)
+    //   m : mat2x2<f16>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : mat2x2<f16> = s.m;
+    // }
+    ProgramBuilder b;
+    b.Enable(wgsl::Extension::kF16);
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.mat2x2<f16>(),
+                          Vector{
+                              b.MemberOffset(16_u),
+                              b.create<ast::StrideAttribute>(32u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("x", b.ty.mat2x2<f16>(), b.MemberAccessor("s", "m"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+enable f16;
+
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16) */
+  m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f16>, 2u>) -> mat2x2<f16> {
+  return mat2x2<f16>(arr[0u], arr[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : mat2x2<f16> = arr_to_mat2x2_stride_32(s.m);
+}
+)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn_F16) {
+    // enable f16;
+    // struct S {
+    //   @offset(16) @stride(32)
+    //   @internal(ignore_stride_attribute)
+    //   m : mat2x2<f16>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : vec2<f32> = s.m[1];
+    // }
+    ProgramBuilder b;
+    b.Enable(wgsl::Extension::kF16);
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.mat2x2<f16>(),
+                          Vector{
+                              b.MemberOffset(16_u),
+                              b.create<ast::StrideAttribute>(32u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Decl(b.Let("x", b.ty.vec2<f16>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+enable f16;
+
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16) */
+  m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : vec2<f16> = s.m[1i];
+}
+)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride_F16) {
+    // enable f16;
+    // struct S {
+    //   @offset(16) @stride(4)
+    //   @internal(ignore_stride_attribute)
+    //   m : mat2x2<f16>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : mat2x2<f16> = s.m;
+    // }
+    ProgramBuilder b;
+    b.Enable(wgsl::Extension::kF16);
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.mat2x2<f16>(),
+                          Vector{
+                              b.MemberOffset(16_u),
+                              b.create<ast::StrideAttribute>(4u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("x", b.ty.mat2x2<f16>(), b.MemberAccessor("s", "m"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+enable f16;
+
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat2x2<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : mat2x2<f16> = s.m;
+}
+)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix_F16) {
+    // enable f16;
+    // struct S {
+    //   @offset(8) @stride(32)
+    //   @internal(ignore_stride_attribute)
+    //   m : mat2x2<f16>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   s.m = mat2x2<f16>(vec2<f16>(1.0, 2.0), vec2<f16>(3.0, 4.0));
+    // }
+    ProgramBuilder b;
+    b.Enable(wgsl::Extension::kF16);
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.mat2x2<f16>(),
+                          Vector{
+                              b.MemberOffset(8_u),
+                              b.create<ast::StrideAttribute>(32u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Assign(b.MemberAccessor("s", "m"),
+                     b.Call<mat2x2<f16>>(b.Call<vec2<f16>>(1_h, 2_h), b.Call<vec2<f16>>(3_h, 4_h))),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+enable f16;
+
+struct S {
+  @size(8)
+  padding_0 : u32,
+  /* @offset(8) */
+  m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn mat2x2_stride_32_to_arr(m : mat2x2<f16>) -> @stride(32) array<vec2<f16>, 2u> {
+  return @stride(32) array<vec2<f16>, 2u>(m[0u], m[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+  s.m = mat2x2_stride_32_to_arr(mat2x2<f16>(vec2<f16>(1.0h, 2.0h), vec2<f16>(3.0h, 4.0h)));
+}
+)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn_F16) {
+    // enable f16;
+    // struct S {
+    //   @offset(8) @stride(32)
+    //   @internal(ignore_stride_attribute)
+    //   m : mat2x2<f16>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   s.m[1] = vec2<f16>(1.0, 2.0);
+    // }
+    ProgramBuilder b;
+    b.Enable(wgsl::Extension::kF16);
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.mat2x2<f16>(),
+                          Vector{
+                              b.MemberOffset(8_u),
+                              b.create<ast::StrideAttribute>(32u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.Call<vec2<f16>>(1_h, 2_h)),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+enable f16;
+
+struct S {
+  @size(8)
+  padding_0 : u32,
+  /* @offset(8) */
+  m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  s.m[1i] = vec2<f16>(1.0h, 2.0h);
+}
+)";
+
+    auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
     // struct S {
     //   @offset(8) @stride(32)