[spirv-reader] Support strided f16 matrices
The DecomposeStridedMatrix transform was hardcoded to create f32
arrays, so just update it to use f16 if needed and add some tests.
Bug: 377728743
Change-Id: I882047bcb7ee855053718076a28661356201e9ca
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/215614
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
index a50c38c..8509871 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.cc
@@ -57,7 +57,15 @@
/// @returns the identifier of an array that holds an vector column for each row of the matrix.
ast::Type array(ast::Builder* b) const {
- return b->ty.array(b->ty.vec<f32>(matrix->Rows()), u32(matrix->Columns()),
+ ast::Type col_type;
+ if (matrix->Type()->Is<core::type::F32>()) {
+ col_type = b->ty.vec<f32>(matrix->Rows());
+ } else if (matrix->Type()->Is<core::type::F16>()) {
+ col_type = b->ty.vec<f16>(matrix->Rows());
+ } else {
+ TINT_UNREACHABLE();
+ }
+ return b->ty.array(col_type, u32(matrix->Columns()),
Vector{
b->Stride(stride),
});
diff --git a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
index a69ff9d..e0dc2c8 100644
--- a/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
@@ -637,6 +637,305 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_F16) {
+ // enable f16;
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f16>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f16> = s.m;
+ // }
+ ProgramBuilder b;
+ b.Enable(wgsl::Extension::kF16);
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.mat2x2<f16>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.create<ast::StrideAttribute>(32u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.mat2x2<f16>(), b.MemberAccessor("s", "m"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16) */
+ m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : @stride(32) array<vec2<f16>, 2u>) -> mat2x2<f16> {
+ return mat2x2<f16>(arr[0u], arr[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x2<f16> = arr_to_mat2x2_stride_32(s.m);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn_F16) {
+ // enable f16;
+ // struct S {
+ // @offset(16) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f16>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ b.Enable(wgsl::Extension::kF16);
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.mat2x2<f16>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.create<ast::StrideAttribute>(32u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.vec2<f16>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16) */
+ m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : vec2<f16> = s.m[1i];
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride_F16) {
+ // enable f16;
+ // struct S {
+ // @offset(16) @stride(4)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f16>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x2<f16> = s.m;
+ // }
+ ProgramBuilder b;
+ b.Enable(wgsl::Extension::kF16);
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.mat2x2<f16>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.create<ast::StrideAttribute>(4u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.mat2x2<f16>(), b.MemberAccessor("s", "m"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat2x2<f16>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x2<f16> = s.m;
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix_F16) {
+ // enable f16;
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f16>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.m = mat2x2<f16>(vec2<f16>(1.0, 2.0), vec2<f16>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ b.Enable(wgsl::Extension::kF16);
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.mat2x2<f16>(),
+ Vector{
+ b.MemberOffset(8_u),
+ b.create<ast::StrideAttribute>(32u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.Call<mat2x2<f16>>(b.Call<vec2<f16>>(1_h, 2_h), b.Call<vec2<f16>>(3_h, 4_h))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(8)
+ padding_0 : u32,
+ /* @offset(8) */
+ m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+fn mat2x2_stride_32_to_arr(m : mat2x2<f16>) -> @stride(32) array<vec2<f16>, 2u> {
+ return @stride(32) array<vec2<f16>, 2u>(m[0u], m[1u]);
+}
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.m = mat2x2_stride_32_to_arr(mat2x2<f16>(vec2<f16>(1.0h, 2.0h), vec2<f16>(3.0h, 4.0h)));
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn_F16) {
+ // enable f16;
+ // struct S {
+ // @offset(8) @stride(32)
+ // @internal(ignore_stride_attribute)
+ // m : mat2x2<f16>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.m[1] = vec2<f16>(1.0, 2.0);
+ // }
+ ProgramBuilder b;
+ b.Enable(wgsl::Extension::kF16);
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.mat2x2<f16>(),
+ Vector{
+ b.MemberOffset(8_u),
+ b.create<ast::StrideAttribute>(32u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.Call<vec2<f16>>(1_h, 2_h)),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+enable f16;
+
+struct S {
+ @size(8)
+ padding_0 : u32,
+ /* @offset(8) */
+ m : @stride(32) array<vec2<f16>, 2u>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.m[1i] = vec2<f16>(1.0h, 2.0h);
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, DecomposeStridedMatrix>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
// struct S {
// @offset(8) @stride(32)