reader/spirv: Handle the MatrixStride decoration
Add `transform::DecomposeStridedMatrix`, which replaces matrix members of storage or uniform buffer structures, that have a [[stride]] decoration, into an array
of N column vectors.
This is required to correctly handle `mat2x2` matrices in UBOs, as std140 rules will expect a default stride of 16 bytes, when in WGSL the default structure layout expects a stride of 8 bytes.
Bug: tint:1047
Change-Id: If5ca3c6ec087bbc1ac31a8d9a657b99bf34042a4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/59840
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index acb21d8..32a2b15 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -426,6 +426,8 @@
"transform/canonicalize_entry_point_io.h",
"transform/decompose_memory_access.cc",
"transform/decompose_memory_access.h",
+ "transform/decompose_strided_matrix.cc",
+ "transform/decompose_strided_matrix.h",
"transform/external_texture_transform.cc",
"transform/external_texture_transform.h",
"transform/first_index_offset.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b56686d..fbf50af 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -296,6 +296,8 @@
transform/canonicalize_entry_point_io.h
transform/decompose_memory_access.cc
transform/decompose_memory_access.h
+ transform/decompose_strided_matrix.cc
+ transform/decompose_strided_matrix.h
transform/external_texture_transform.cc
transform/external_texture_transform.h
transform/first_index_offset.cc
@@ -904,6 +906,7 @@
transform/calculate_array_length_test.cc
transform/canonicalize_entry_point_io_test.cc
transform/decompose_memory_access_test.cc
+ transform/decompose_strided_matrix_test.cc
transform/external_texture_transform_test.cc
transform/first_index_offset_test.cc
transform/fold_constants_test.cc
diff --git a/src/ast/disable_validation_decoration.cc b/src/ast/disable_validation_decoration.cc
index ca59d80..846530a 100644
--- a/src/ast/disable_validation_decoration.cc
+++ b/src/ast/disable_validation_decoration.cc
@@ -40,6 +40,8 @@
return "disable_validation__entry_point_parameter";
case DisabledValidation::kIgnoreConstructibleFunctionParameter:
return "disable_validation__ignore_constructible_function_parameter";
+ case DisabledValidation::kIgnoreStrideDecoration:
+ return "disable_validation__ignore_stride";
}
return "<invalid>";
}
diff --git a/src/ast/disable_validation_decoration.h b/src/ast/disable_validation_decoration.h
index 60a9fb7..3ebf0bc 100644
--- a/src/ast/disable_validation_decoration.h
+++ b/src/ast/disable_validation_decoration.h
@@ -40,6 +40,9 @@
/// When applied to a function parameter, the validator will not
/// check if parameter type is constructible
kIgnoreConstructibleFunctionParameter,
+ /// When applied to a member decoration, a stride decoration may be applied to
+ /// non-array types.
+ kIgnoreStrideDecoration,
};
/// An internal decoration used to tell the validator to ignore specific
diff --git a/src/ast/internal_decoration.cc b/src/ast/internal_decoration.cc
index 47db806..f5e2237 100644
--- a/src/ast/internal_decoration.cc
+++ b/src/ast/internal_decoration.cc
@@ -32,7 +32,7 @@
std::ostream& out,
size_t indent) const {
make_indent(out, indent);
- out << "tint_internal(" << InternalName() << ")" << std::endl;
+ out << "tint_internal(" << InternalName() << ")";
}
} // namespace ast
diff --git a/src/reader/spirv/parser.cc b/src/reader/spirv/parser.cc
index c382d90..bcbc193 100644
--- a/src/reader/spirv/parser.cc
+++ b/src/reader/spirv/parser.cc
@@ -17,6 +17,10 @@
#include <utility>
#include "src/reader/spirv/parser_impl.h"
+#include "src/transform/decompose_strided_matrix.h"
+#include "src/transform/inline_pointer_lets.h"
+#include "src/transform/manager.h"
+#include "src/transform/simplify.h"
namespace tint {
namespace reader {
@@ -40,7 +44,19 @@
ProgramBuilder output;
CloneContext(&output, &program_with_disjoint_ast, false).Clone();
- return Program(std::move(output));
+ auto program = Program(std::move(output));
+
+ // If the generated program contains matrices with a custom MatrixStride
+ // attribute then we need to decompose these into an array of vectors
+ if (transform::DecomposeStridedMatrix::ShouldRun(&program)) {
+ transform::Manager manager;
+ manager.Add<transform::InlinePointerLets>();
+ manager.Add<transform::Simplify>();
+ manager.Add<transform::DecomposeStridedMatrix>();
+ return manager.Run(&program).program;
+ }
+
+ return program;
}
} // namespace spirv
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 7981c0a..de9fe80 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -21,6 +21,7 @@
#include "source/opt/build_module.h"
#include "src/ast/bitcast_expression.h"
+#include "src/ast/disable_validation_decoration.h"
#include "src/ast/interpolate_decoration.h"
#include "src/ast/override_decoration.h"
#include "src/ast/struct_block_decoration.h"
@@ -439,13 +440,14 @@
return "SPIR-V type " + std::to_string(type_id);
}
-ast::Decoration* ParserImpl::ConvertMemberDecoration(
+ast::DecorationList ParserImpl::ConvertMemberDecoration(
uint32_t struct_type_id,
uint32_t member_index,
+ const Type* member_ty,
const Decoration& decoration) {
if (decoration.empty()) {
Fail() << "malformed SPIR-V decoration: it's empty";
- return nullptr;
+ return {};
}
switch (decoration[0]) {
case SpvDecorationOffset:
@@ -454,38 +456,49 @@
<< "malformed Offset decoration: expected 1 literal operand, has "
<< decoration.size() - 1 << ": member " << member_index << " of "
<< ShowType(struct_type_id);
- return nullptr;
+ return {};
}
- return create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]);
+ return {
+ create<ast::StructMemberOffsetDecoration>(Source{}, decoration[1]),
+ };
case SpvDecorationNonReadable:
// WGSL doesn't have a member decoration for this. Silently drop it.
- return nullptr;
+ return {};
case SpvDecorationNonWritable:
// WGSL doesn't have a member decoration for this.
- return nullptr;
+ return {};
case SpvDecorationColMajor:
// WGSL only supports column major matrices.
- return nullptr;
+ return {};
case SpvDecorationRelaxedPrecision:
// WGSL doesn't support relaxed precision.
- return nullptr;
+ return {};
case SpvDecorationRowMajor:
Fail() << "WGSL does not support row-major matrices: can't "
"translate member "
<< member_index << " of " << ShowType(struct_type_id);
- return nullptr;
+ return {};
case SpvDecorationMatrixStride: {
if (decoration.size() != 2) {
Fail() << "malformed MatrixStride decoration: expected 1 literal "
"operand, has "
<< decoration.size() - 1 << ": member " << member_index << " of "
<< ShowType(struct_type_id);
- return nullptr;
+ return {};
}
- // TODO(dneto): Fail if the matrix stride is not allocation size of the
- // column vector of the underlying matrix. This would need to unpack
- // any levels of array-ness.
- return nullptr;
+ uint32_t stride = decoration[1];
+ uint32_t natural_stride = 0;
+ if (auto* mat = member_ty->As<Matrix>()) {
+ natural_stride = (mat->rows == 2) ? 8 : 16;
+ }
+ if (stride == natural_stride) {
+ return {};
+ }
+ return {
+ create<ast::StrideDecoration>(Source{}, decoration[1]),
+ builder_.ASTNodes().Create<ast::DisableValidationDecoration>(
+ builder_.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ };
}
default:
// TODO(dneto): Support the remaining member decorations.
@@ -493,7 +506,7 @@
}
Fail() << "unhandled member decoration: " << decoration[0] << " on member "
<< member_index << " of " << ShowType(struct_type_id);
- return nullptr;
+ return {};
}
bool ParserImpl::BuildInternalModule() {
@@ -1126,14 +1139,14 @@
// the members are non-writable.
is_non_writable = true;
} else {
- auto* ast_member_decoration =
- ConvertMemberDecoration(type_id, member_index, decoration);
+ auto decos = ConvertMemberDecoration(type_id, member_index,
+ ast_member_ty, decoration);
+ for (auto* deco : decos) {
+ ast_member_decorations.emplace_back(deco);
+ }
if (!success_) {
return nullptr;
}
- if (ast_member_decoration) {
- ast_member_decorations.push_back(ast_member_decoration);
- }
}
}
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 7a69708..6f80bcb 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -272,16 +272,19 @@
ast::Decoration* SetLocation(ast::DecorationList* decos,
ast::Decoration* replacement);
- /// Converts a SPIR-V struct member decoration. If the decoration is
- /// recognized but deliberately dropped, then returns nullptr without a
- /// diagnostic. On failure, emits a diagnostic and returns nullptr.
+ /// Converts a SPIR-V struct member decoration into a number of AST
+ /// decorations. If the decoration is recognized but deliberately dropped,
+ /// then returns an empty list without a diagnostic. On failure, emits a
+ /// diagnostic and returns an empty list.
/// @param struct_type_id the ID of the struct type
/// @param member_index the index of the member
+ /// @param member_ty the type of the member
/// @param decoration an encoded SPIR-V Decoration
- /// @returns the corresponding ast::StructuMemberDecoration
- ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
- uint32_t member_index,
- const Decoration& decoration);
+ /// @returns the AST decorations
+ ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
+ uint32_t member_index,
+ const Type* member_ty,
+ const Decoration& decoration);
/// Returns a string for the given type. If the type ID is invalid,
/// then the resulting string only names the type ID.
diff --git a/src/reader/spirv/parser_impl_convert_member_decoration_test.cc b/src/reader/spirv/parser_impl_convert_member_decoration_test.cc
index 01f29d4..d431fb5 100644
--- a/src/reader/spirv/parser_impl_convert_member_decoration_test.cc
+++ b/src/reader/spirv/parser_impl_convert_member_decoration_test.cc
@@ -25,16 +25,17 @@
TEST_F(SpvParserTest, ConvertMemberDecoration_Empty) {
auto p = parser(std::vector<uint32_t>{});
- auto* result = p->ConvertMemberDecoration(1, 1, {});
- EXPECT_EQ(result, nullptr);
+ auto result = p->ConvertMemberDecoration(1, 1, nullptr, {});
+ EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed SPIR-V decoration: it's empty"));
}
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithoutOperand) {
auto p = parser(std::vector<uint32_t>{});
- auto* result = p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset});
- EXPECT_EQ(result, nullptr);
+ auto result =
+ p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset});
+ EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
"operand, has 0: member 13 of SPIR-V type 12"));
}
@@ -42,9 +43,9 @@
TEST_F(SpvParserTest, ConvertMemberDecoration_OffsetWithTooManyOperands) {
auto p = parser(std::vector<uint32_t>{});
- auto* result =
- p->ConvertMemberDecoration(12, 13, {SpvDecorationOffset, 3, 4});
- EXPECT_EQ(result, nullptr);
+ auto result =
+ p->ConvertMemberDecoration(12, 13, nullptr, {SpvDecorationOffset, 3, 4});
+ EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("malformed Offset decoration: expected 1 literal "
"operand, has 2: member 13 of SPIR-V type 12"));
}
@@ -52,32 +53,100 @@
TEST_F(SpvParserTest, ConvertMemberDecoration_Offset) {
auto p = parser(std::vector<uint32_t>{});
- auto* result = p->ConvertMemberDecoration(1, 1, {SpvDecorationOffset, 8});
- ASSERT_NE(result, nullptr);
- EXPECT_TRUE(result->Is<ast::StructMemberOffsetDecoration>());
- auto* offset_deco = result->As<ast::StructMemberOffsetDecoration>();
+ auto result =
+ p->ConvertMemberDecoration(1, 1, nullptr, {SpvDecorationOffset, 8});
+ ASSERT_FALSE(result.empty());
+ EXPECT_TRUE(result[0]->Is<ast::StructMemberOffsetDecoration>());
+ auto* offset_deco = result[0]->As<ast::StructMemberOffsetDecoration>();
ASSERT_NE(offset_deco, nullptr);
EXPECT_EQ(offset_deco->offset(), 8u);
EXPECT_TRUE(p->error().empty());
}
+TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Natural) {
+ auto p = parser(std::vector<uint32_t>{});
+
+ spirv::F32 f32;
+ spirv::Matrix matrix(&f32, 2, 2);
+ auto result =
+ p->ConvertMemberDecoration(1, 1, &matrix, {SpvDecorationMatrixStride, 8});
+ EXPECT_TRUE(result.empty());
+ EXPECT_TRUE(p->error().empty());
+}
+
+TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x2_Stride_Custom) {
+ auto p = parser(std::vector<uint32_t>{});
+
+ spirv::F32 f32;
+ spirv::Matrix matrix(&f32, 2, 2);
+ auto result = p->ConvertMemberDecoration(1, 1, &matrix,
+ {SpvDecorationMatrixStride, 16});
+ ASSERT_FALSE(result.empty());
+ EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
+ auto* stride_deco = result[0]->As<ast::StrideDecoration>();
+ ASSERT_NE(stride_deco, nullptr);
+ EXPECT_EQ(stride_deco->stride(), 16u);
+ EXPECT_TRUE(p->error().empty());
+}
+
+TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Natural) {
+ auto p = parser(std::vector<uint32_t>{});
+
+ spirv::F32 f32;
+ spirv::Matrix matrix(&f32, 2, 4);
+ auto result = p->ConvertMemberDecoration(1, 1, &matrix,
+ {SpvDecorationMatrixStride, 16});
+ EXPECT_TRUE(result.empty());
+ EXPECT_TRUE(p->error().empty());
+}
+
+TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x4_Stride_Custom) {
+ auto p = parser(std::vector<uint32_t>{});
+
+ spirv::F32 f32;
+ spirv::Matrix matrix(&f32, 2, 4);
+ auto result = p->ConvertMemberDecoration(1, 1, &matrix,
+ {SpvDecorationMatrixStride, 64});
+ ASSERT_FALSE(result.empty());
+ EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
+ auto* stride_deco = result[0]->As<ast::StrideDecoration>();
+ ASSERT_NE(stride_deco, nullptr);
+ EXPECT_EQ(stride_deco->stride(), 64u);
+ EXPECT_TRUE(p->error().empty());
+}
+
+TEST_F(SpvParserTest, ConvertMemberDecoration_Matrix2x3_Stride_Custom) {
+ auto p = parser(std::vector<uint32_t>{});
+
+ spirv::F32 f32;
+ spirv::Matrix matrix(&f32, 2, 3);
+ auto result = p->ConvertMemberDecoration(1, 1, &matrix,
+ {SpvDecorationMatrixStride, 32});
+ ASSERT_FALSE(result.empty());
+ EXPECT_TRUE(result[0]->Is<ast::StrideDecoration>());
+ auto* stride_deco = result[0]->As<ast::StrideDecoration>();
+ ASSERT_NE(stride_deco, nullptr);
+ EXPECT_EQ(stride_deco->stride(), 32u);
+ EXPECT_TRUE(p->error().empty());
+}
+
TEST_F(SpvParserTest, ConvertMemberDecoration_RelaxedPrecision) {
// WGSL does not support relaxed precision. Drop it.
// It's functionally correct to use full precision f32 instead of
// relaxed precision f32.
auto p = parser(std::vector<uint32_t>{});
- auto* result =
- p->ConvertMemberDecoration(1, 1, {SpvDecorationRelaxedPrecision});
- EXPECT_EQ(result, nullptr);
+ auto result = p->ConvertMemberDecoration(1, 1, nullptr,
+ {SpvDecorationRelaxedPrecision});
+ EXPECT_TRUE(result.empty());
EXPECT_TRUE(p->error().empty());
}
TEST_F(SpvParserTest, ConvertMemberDecoration_UnhandledDecoration) {
auto p = parser(std::vector<uint32_t>{});
- auto* result = p->ConvertMemberDecoration(12, 13, {12345678});
- EXPECT_EQ(result, nullptr);
+ auto result = p->ConvertMemberDecoration(12, 13, nullptr, {12345678});
+ EXPECT_TRUE(result.empty());
EXPECT_THAT(p->error(), Eq("unhandled member decoration: 12345678 on member "
"13 of SPIR-V type 12"));
}
diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc
index b844a16..e884362 100644
--- a/src/reader/spirv/parser_impl_module_var_test.cc
+++ b/src/reader/spirv/parser_impl_module_var_test.cc
@@ -2015,7 +2015,7 @@
})")) << module_str;
}
-TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Dropped) {
+TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration_Natural_Dropped) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
OpDecorate %myvar DescriptorSet 0
@@ -2054,6 +2054,45 @@
})")) << module_str;
}
+TEST_F(SpvModuleScopeVarParserTest, MatrixStrideDecoration) {
+ auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
+ OpName %myvar "myvar"
+ OpDecorate %myvar DescriptorSet 0
+ OpDecorate %myvar Binding 0
+ OpDecorate %s Block
+ OpMemberDecorate %s 0 MatrixStride 64
+ OpMemberDecorate %s 0 Offset 0
+ %void = OpTypeVoid
+ %voidfn = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+ %m3v2float = OpTypeMatrix %v2float 3
+
+ %s = OpTypeStruct %m3v2float
+ %ptr_sb_s = OpTypePointer StorageBuffer %s
+ %myvar = OpVariable %ptr_sb_s StorageBuffer
+ )" + MainBody()));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << p->error();
+ EXPECT_TRUE(p->error().empty());
+ const auto module_str = p->program().to_str();
+ EXPECT_THAT(module_str, HasSubstr(R"(
+ Struct S {
+ [[block]]
+ StructMember{[[ stride 64 tint_internal(disable_validation__ignore_stride) offset 0 ]] field0: __mat_2_3__f32}
+ }
+ Variable{
+ Decorations{
+ GroupDecoration{0}
+ BindingDecoration{0}
+ }
+ myvar
+ storage
+ read_write
+ __type_name_S
+ }
+})")) << module_str;
+}
+
TEST_F(SpvModuleScopeVarParserTest, RowMajorDecoration_IsError) {
auto p = parser(test::Assemble(Preamble() + FragMain() + R"(
OpName %myvar "myvar"
@@ -2620,7 +2659,8 @@
private
undefined
__i32
- })")) <<module_str;
+ })"))
+ << module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@@ -3006,7 +3046,8 @@
private
undefined
__array__u32_1
- })")) <<module_str;
+ })"))
+ << module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@@ -3149,7 +3190,8 @@
private
undefined
__array__u32_1
- })")) <<module_str;
+ })"))
+ << module_str;
// Correct creation of value
EXPECT_THAT(module_str, HasSubstr(R"(
@@ -5543,7 +5585,6 @@
// {"NumWorkgroups", "%uint", "num_workgroups"}
// {"NumWorkgroups", "%int", "num_workgroups"}
-
TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) {
const std::string assembly =
R"(
diff --git a/src/reader/spirv/parser_impl_test_helper.h b/src/reader/spirv/parser_impl_test_helper.h
index d93b9a6..a10ac13 100644
--- a/src/reader/spirv/parser_impl_test_helper.h
+++ b/src/reader/spirv/parser_impl_test_helper.h
@@ -184,18 +184,21 @@
return impl_.GetDecorationsForMember(id, member_index);
}
- /// Converts a SPIR-V struct member decoration. If the decoration is
- /// recognized but deliberately dropped, then returns nullptr without a
- /// diagnostic. On failure, emits a diagnostic and returns nullptr.
+ /// Converts a SPIR-V struct member decoration into a number of AST
+ /// decorations. If the decoration is recognized but deliberately dropped,
+ /// then returns an empty list without a diagnostic. On failure, emits a
+ /// diagnostic and returns an empty list.
/// @param struct_type_id the ID of the struct type
/// @param member_index the index of the member
+ /// @param member_ty the type of the member
/// @param decoration an encoded SPIR-V Decoration
- /// @returns the corresponding ast::StructuMemberDecoration
- ast::Decoration* ConvertMemberDecoration(uint32_t struct_type_id,
- uint32_t member_index,
- const Decoration& decoration) {
+ /// @returns the AST decorations
+ ast::DecorationList ConvertMemberDecoration(uint32_t struct_type_id,
+ uint32_t member_index,
+ const Type* member_ty,
+ const Decoration& decoration) {
return impl_.ConvertMemberDecoration(struct_type_id, member_index,
- decoration);
+ member_ty, decoration);
}
/// For a SPIR-V ID that might define a sampler, image, or sampled image
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index e366e23..a395cd8 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -3937,13 +3937,20 @@
auto has_position = false;
ast::InvariantDecoration* invariant_attribute = nullptr;
for (auto* deco : member->Declaration()->decorations()) {
- if (!(deco->Is<ast::BuiltinDecoration>() ||
- deco->Is<ast::InterpolateDecoration>() ||
- deco->Is<ast::InvariantDecoration>() ||
- deco->Is<ast::LocationDecoration>() ||
- deco->Is<ast::StructMemberOffsetDecoration>() ||
- deco->Is<ast::StructMemberSizeDecoration>() ||
- deco->Is<ast::StructMemberAlignDecoration>())) {
+ if (!deco->IsAnyOf<ast::BuiltinDecoration, //
+ ast::InternalDecoration, //
+ ast::InterpolateDecoration, //
+ ast::InvariantDecoration, //
+ ast::LocationDecoration, //
+ ast::StructMemberOffsetDecoration, //
+ ast::StructMemberSizeDecoration, //
+ ast::StructMemberAlignDecoration>()) {
+ if (deco->Is<ast::StrideDecoration>() &&
+ IsValidationDisabled(
+ member->Declaration()->decorations(),
+ ast::DisabledValidation::kIgnoreStrideDecoration)) {
+ continue;
+ }
AddError("decoration is not valid for structure members",
deco->source());
return false;
diff --git a/src/transform/decompose_strided_matrix.cc b/src/transform/decompose_strided_matrix.cc
new file mode 100644
index 0000000..786f509
--- /dev/null
+++ b/src/transform/decompose_strided_matrix.cc
@@ -0,0 +1,251 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/decompose_strided_matrix.h"
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "src/program_builder.h"
+#include "src/sem/expression.h"
+#include "src/sem/member_accessor_expression.h"
+#include "src/transform/inline_pointer_lets.h"
+#include "src/transform/simplify.h"
+#include "src/utils/get_or_create.h"
+#include "src/utils/hash.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStridedMatrix);
+
+namespace tint {
+namespace transform {
+namespace {
+
+/// MatrixInfo describes a matrix member with a custom stride
+struct MatrixInfo {
+ /// The stride in bytes between columns of the matrix
+ uint32_t stride = 0;
+ /// The type of the matrix
+ sem::Matrix const* matrix = nullptr;
+
+ /// @returns a new ast::Array that holds an vector column for each row of the
+ /// matrix.
+ ast::Array* array(ProgramBuilder* b) const {
+ return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
+ matrix->columns(), stride);
+ }
+
+ /// Equality operator
+ bool operator==(const MatrixInfo& info) const {
+ return stride == info.stride && matrix == info.matrix;
+ }
+ /// Hash function
+ struct Hasher {
+ size_t operator()(const MatrixInfo& t) const {
+ return utils::Hash(t.stride, t.matrix);
+ }
+ };
+};
+
+/// Return type of the callback function of GatherCustomStrideMatrixMembers
+enum GatherResult { kContinue, kStop };
+
+/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
+/// storage and uniform structs, which are of a matrix type, and have a custom
+/// matrix stride attribute. For each matrix member found, `callback` is called.
+/// `callback` is a function with the signature:
+/// GatherResult(const sem::StructMember* member,
+/// sem::Matrix* matrix,
+/// uint32_t stride)
+/// If `callback` return GatherResult::kStop, then the scanning will immediately
+/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
+/// scanning will continue.
+template <typename F>
+void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
+ for (auto* node : program->ASTNodes().Objects()) {
+ if (auto* str = node->As<ast::Struct>()) {
+ auto* str_ty = program->Sem().Get(str);
+ if (!str_ty->UsedAs(ast::StorageClass::kUniform) &&
+ !str_ty->UsedAs(ast::StorageClass::kStorage)) {
+ continue;
+ }
+ for (auto* member : str_ty->Members()) {
+ auto* matrix = member->Type()->As<sem::Matrix>();
+ if (!matrix) {
+ continue;
+ }
+ auto* deco = ast::GetDecoration<ast::StrideDecoration>(
+ member->Declaration()->decorations());
+ if (!deco) {
+ continue;
+ }
+ uint32_t stride = deco->stride();
+ if (matrix->ColumnStride() == stride) {
+ continue;
+ }
+ if (callback(member, matrix, stride) == GatherResult::kStop) {
+ return;
+ }
+ }
+ }
+ }
+}
+
+} // namespace
+
+DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
+
+DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
+
+bool DecomposeStridedMatrix::ShouldRun(const Program* program) {
+ bool should_run = false;
+ GatherCustomStrideMatrixMembers(
+ program, [&](const sem::StructMember*, sem::Matrix*, uint32_t) {
+ should_run = true;
+ return GatherResult::kStop;
+ });
+ return should_run;
+}
+
+void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
+ if (!Requires<InlinePointerLets, Simplify>(ctx)) {
+ return;
+ }
+
+ // Scan the program for all storage and uniform structure matrix members with
+ // a custom stride attribute. Replace these matrices with an equivalent array,
+ // and populate the `decomposed` map with the members that have been replaced.
+ std::unordered_map<ast::StructMember*, MatrixInfo> decomposed;
+ GatherCustomStrideMatrixMembers(
+ ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
+ uint32_t stride) {
+ // We've got ourselves a struct member of a matrix type with a custom
+ // stride. Replace this with an array of column vectors.
+ MatrixInfo info{stride, matrix};
+ auto* replacement = ctx.dst->Member(
+ member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
+ ctx.Replace(member->Declaration(), replacement);
+ decomposed.emplace(member->Declaration(), info);
+ return GatherResult::kContinue;
+ });
+
+ // For all expressions where a single matrix column vector was indexed, we can
+ // preserve these without calling conversion functions.
+ // Example:
+ // ssbo.mat[2] -> ssbo.mat[2]
+ ctx.ReplaceAll(
+ [&](ast::ArrayAccessorExpression* expr) -> ast::ArrayAccessorExpression* {
+ if (auto* access =
+ ctx.src->Sem().Get<sem::StructMemberAccess>(expr->array())) {
+ auto it = decomposed.find(access->Member()->Declaration());
+ if (it != decomposed.end()) {
+ auto* obj = ctx.CloneWithoutTransform(expr->array());
+ auto* idx = ctx.Clone(expr->idx_expr());
+ return ctx.dst->IndexAccessor(obj, idx);
+ }
+ }
+ return nullptr;
+ });
+
+ // For all struct member accesses to the matrix on the LHS of an assignment,
+ // we need to convert the matrix to the array before assigning to the
+ // structure.
+ // Example:
+ // ssbo.mat = mat_to_arr(m)
+ std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
+ ctx.ReplaceAll([&](ast::AssignmentStatement* stmt) -> ast::Statement* {
+ if (auto* access =
+ ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs())) {
+ auto it = decomposed.find(access->Member()->Declaration());
+ if (it == decomposed.end()) {
+ return nullptr;
+ }
+ MatrixInfo info = it->second;
+ auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
+ auto name = ctx.dst->Symbols().New(
+ "mat" + std::to_string(info.matrix->columns()) + "x" +
+ std::to_string(info.matrix->rows()) + "_stride_" +
+ std::to_string(info.stride) + "_to_arr");
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
+ auto array = [&] { return info.array(ctx.dst); };
+
+ auto mat = ctx.dst->Sym("mat");
+ ast::ExpressionList columns(info.matrix->columns());
+ for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
+ columns[i] = ctx.dst->IndexAccessor(mat, i);
+ }
+ ctx.dst->Func(name,
+ {
+ ctx.dst->Param(mat, matrix()),
+ },
+ array(),
+ {
+ ctx.dst->Return(ctx.dst->Construct(array(), columns)),
+ });
+ return name;
+ });
+ auto* lhs = ctx.CloneWithoutTransform(stmt->lhs());
+ auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs()));
+ return ctx.dst->Assign(lhs, rhs);
+ }
+ return nullptr;
+ });
+
+ // For all other struct member accesses, we need to convert the array to the
+ // matrix type. Example:
+ // m = arr_to_mat(ssbo.mat)
+ std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
+ ctx.ReplaceAll([&](ast::MemberAccessorExpression* expr) -> ast::Expression* {
+ if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
+ auto it = decomposed.find(access->Member()->Declaration());
+ if (it == decomposed.end()) {
+ return nullptr;
+ }
+ MatrixInfo info = it->second;
+ auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
+ auto name = ctx.dst->Symbols().New(
+ "arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
+ std::to_string(info.matrix->rows()) + "_stride_" +
+ std::to_string(info.stride));
+
+ auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
+ auto array = [&] { return info.array(ctx.dst); };
+
+ auto arr = ctx.dst->Sym("arr");
+ ast::ExpressionList columns(info.matrix->columns());
+ for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
+ columns[i] = ctx.dst->IndexAccessor(arr, i);
+ }
+ ctx.dst->Func(
+ name,
+ {
+ ctx.dst->Param(arr, array()),
+ },
+ matrix(),
+ {
+ ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
+ });
+ return name;
+ });
+ return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
+ }
+ return nullptr;
+ });
+
+ ctx.Clone();
+}
+
+} // namespace transform
+} // namespace tint
diff --git a/src/transform/decompose_strided_matrix.h b/src/transform/decompose_strided_matrix.h
new file mode 100644
index 0000000..3283049
--- /dev/null
+++ b/src/transform/decompose_strided_matrix.h
@@ -0,0 +1,54 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
+#define SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
+
+#include "src/transform/transform.h"
+
+namespace tint {
+namespace transform {
+
+/// DecomposeStridedMatrix transforms replaces matrix members of storage or
+/// uniform buffer structures, that have a [[stride]] decoration, into an array
+/// of N column vectors.
+/// This transform is used by the SPIR-V reader to handle the SPIR-V
+/// MatrixStride decoration.
+class DecomposeStridedMatrix
+ : public Castable<DecomposeStridedMatrix, Transform> {
+ public:
+ /// Constructor
+ DecomposeStridedMatrix();
+
+ /// Destructor
+ ~DecomposeStridedMatrix() override;
+
+ /// @param program the program to inspect
+ /// @returns true if this transform should be run for the given program
+ static bool ShouldRun(const Program* program);
+
+ protected:
+ /// Runs the transform using the CloneContext built for transforming a
+ /// program. Run() is responsible for calling Clone() on the CloneContext.
+ /// @param ctx the CloneContext primed with the input program and
+ /// ProgramBuilder
+ /// @param inputs optional extra transform-specific input data
+ /// @param outputs optional extra transform-specific output data
+ void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
+};
+
+} // namespace transform
+} // namespace tint
+
+#endif // SRC_TRANSFORM_DECOMPOSE_STRIDED_MATRIX_H_
diff --git a/src/transform/decompose_strided_matrix_test.cc b/src/transform/decompose_strided_matrix_test.cc
new file mode 100644
index 0000000..6a205a6
--- /dev/null
+++ b/src/transform/decompose_strided_matrix_test.cc
@@ -0,0 +1,727 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/decompose_strided_matrix.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "src/ast/disable_validation_decoration.h"
+#include "src/program_builder.h"
+#include "src/transform/inline_pointer_lets.h"
+#include "src/transform/simplify.h"
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using DecomposeStridedMatrixTest = TransformTest;
+using f32 = ProgramBuilder::f32;
+
+TEST_F(DecomposeStridedMatrixTest, Empty) {
+ auto* src = R"()";
+ auto* expect = src;
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, MissingDependencyInlinePointerLets) {
+ auto* src = R"()";
+ auto* expect =
+ R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::InlinePointerLets but the dependency was not run)";
+
+ auto got = Run<Simplify, DecomposeStridedMatrix>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, MissingDependencySimplify) {
+ auto* src = R"()";
+ auto* expect =
+ R"(error: tint::transform::DecomposeStridedMatrix depends on tint::transform::Simplify but the dependency was not run)";
+
+ auto got = Run<InlinePointerLets, DecomposeStridedMatrix>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
+ // [[block]]
+ // struct S {
+ // [[offset(16), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<uniform> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(16),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
+ b.GroupAndBinding(0, 0));
+ b.Func(
+ "f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(16)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<uniform> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
+ // [[block]]
+ // struct S {
+ // [[offset(16), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<uniform> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(16),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
+ b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Const("x", b.ty.vec2<f32>(),
+ b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(16)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<uniform> s : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x : vec2<f32> = s.m[1];
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
+ // [[block]]
+ // struct S {
+ // [[offset(16), stride(8)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<uniform> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(16),
+ b.create<ast::StrideDecoration>(8),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kUniform,
+ b.GroupAndBinding(0, 0));
+ b.Func(
+ "f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(16)]]
+ padding : u32;
+ [[stride(8), internal(disable_validation__ignore_stride)]]
+ m : mat2x2<f32>;
+};
+
+[[group(0), binding(0)]] var<uniform> s : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x : mat2x2<f32> = s.m;
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
+ // [[block]]
+ // struct S {
+ // [[offset(8), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<storage, read_write> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(8),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
+ b.Func(
+ "f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(8)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x : mat2x2<f32> = arr_to_mat2x2_stride_32(s.m);
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
+ // [[block]]
+ // struct S {
+ // [[offset(16), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<storage, read_write> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let x : vec2<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(16),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Const("x", b.ty.vec2<f32>(),
+ b.IndexAccessor(b.MemberAccessor("s", "m"), 1))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(16)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> s : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x : vec2<f32> = s.m[1];
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
+ // [[block]]
+ // struct S {
+ // [[offset(8), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<storage, read_write> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(8),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
+ b.vec2<f32>(3.0f, 4.0f))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(8)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> s : S;
+
+fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
+ return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
+ // [[block]]
+ // struct S {
+ // [[offset(8), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<storage, read_write> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // s.m[1] = vec2<f32>(1.0, 2.0);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(8),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1),
+ b.vec2<f32>(1.0f, 2.0f)),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(8)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> s : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ s.m[1] = vec2<f32>(1.0, 2.0);
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
+ // [[block]]
+ // struct S {
+ // [[offset(8), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // [[group(0), binding(0)]] var<storage, read_write> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let a = &s.m;
+ // let b = &*&*(a);
+ // let x = *b;
+ // let y = (*b)[1];
+ // let z = x[1];
+ // (*b) = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // (*b)[1] = vec2<f32>(5.0, 6.0);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(8),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ },
+ {
+ b.StructBlock(),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kStorage,
+ ast::Access::kReadWrite, b.GroupAndBinding(0, 0));
+ b.Func(
+ "f", {}, b.ty.void_(),
+ {
+ b.Decl(
+ b.Const("a", nullptr, b.AddressOf(b.MemberAccessor("s", "m")))),
+ b.Decl(b.Const("b", nullptr,
+ b.AddressOf(b.Deref(b.AddressOf(b.Deref("a")))))),
+ b.Decl(b.Const("x", nullptr, b.Deref("b"))),
+ b.Decl(b.Const("y", nullptr, b.IndexAccessor(b.Deref("b"), 1))),
+ b.Decl(b.Const("z", nullptr, b.IndexAccessor("x", 1))),
+ b.Assign(b.Deref("b"), b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
+ b.vec2<f32>(3.0f, 4.0f))),
+ b.Assign(b.IndexAccessor(b.Deref("b"), 1), b.vec2<f32>(5.0f, 6.0f)),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+[[block]]
+struct S {
+ [[size(8)]]
+ padding : u32;
+ m : [[stride(32)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> s : S;
+
+fn arr_to_mat2x2_stride_32(arr : [[stride(32)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+fn mat2x2_stride_32_to_arr(mat : mat2x2<f32>) -> [[stride(32)]] array<vec2<f32>, 2> {
+ return [[stride(32)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x = arr_to_mat2x2_stride_32(s.m);
+ let y = s.m[1];
+ let z = x[1];
+ s.m = mat2x2_stride_32_to_arr(mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)));
+ s.m[1] = vec2<f32>(5.0, 6.0);
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
+ // struct S {
+ // [[offset(8), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // var<private> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // let x : mat2x2<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(8),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
+ b.Func(
+ "f", {}, b.ty.void_(),
+ {
+ b.Decl(b.Const("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+struct S {
+ [[size(8)]]
+ padding : u32;
+ [[stride(32), internal(disable_validation__ignore_stride)]]
+ m : mat2x2<f32>;
+};
+
+var<private> s : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let x : mat2x2<f32> = s.m;
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
+ // struct S {
+ // [[offset(8), stride(32)]]
+ // [[internal(ignore_stride_decoration)]]
+ // m : mat2x2<f32>;
+ // };
+ // var<private> s : S;
+ //
+ // [[stage(compute), workgroup_size(1)]]
+ // fn f() {
+ // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S",
+ {
+ b.Member(
+ "m", b.ty.mat2x2<f32>(),
+ {
+ b.create<ast::StructMemberOffsetDecoration>(8),
+ b.create<ast::StrideDecoration>(32),
+ b.ASTNodes().Create<ast::DisableValidationDecoration>(
+ b.ID(), ast::DisabledValidation::kIgnoreStrideDecoration),
+ }),
+ });
+ b.Global("s", b.ty.Of(S), ast::StorageClass::kPrivate);
+ b.Func("f", {}, b.ty.void_(),
+ {
+ b.Assign(b.MemberAccessor("s", "m"),
+ b.mat2x2<f32>(b.vec2<f32>(1.0f, 2.0f),
+ b.vec2<f32>(3.0f, 4.0f))),
+ },
+ {
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1),
+ });
+
+ auto* expect = R"(
+struct S {
+ [[size(8)]]
+ padding : u32;
+ [[stride(32), internal(disable_validation__ignore_stride)]]
+ m : mat2x2<f32>;
+};
+
+var<private> s : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+}
+)";
+
+ auto got = Run<InlinePointerLets, Simplify, DecomposeStridedMatrix>(
+ Program(std::move(b)));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace transform
+} // namespace tint
diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h
index 4b35a71..0c4618a 100644
--- a/src/transform/test_helper.h
+++ b/src/transform/test_helper.h
@@ -59,6 +59,16 @@
// Keep this pointer alive after Transform() returns
files_.emplace_back(std::move(file));
+ return Run<TRANSFORMS...>(std::move(program), data);
+ }
+
+ /// Transforms and returns program `program`, transformed using a transform of
+ /// type `TRANSFORM`.
+ /// @param program the input Program
+ /// @param data the optional DataMap to pass to Transform::Run()
+ /// @return the transformed output
+ template <typename... TRANSFORMS>
+ Output Run(Program&& program, const DataMap& data = {}) {
if (!program.IsValid()) {
return Output(std::move(program));
}
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 698efcd..4fadc64 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -696,6 +696,8 @@
out << "size(" << size->size() << ")";
} else if (auto* align = deco->As<ast::StructMemberAlignDecoration>()) {
out << "align(" << align->align() << ")";
+ } else if (auto* stride = deco->As<ast::StrideDecoration>()) {
+ out << "stride(" << stride->stride() << ")";
} else if (auto* internal = deco->As<ast::InternalDecoration>()) {
out << "internal(" << internal->InternalName() << ")";
} else {
diff --git a/test/BUILD.gn b/test/BUILD.gn
index bd33d62..67d2d39 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -289,6 +289,7 @@
"../src/transform/calculate_array_length_test.cc",
"../src/transform/canonicalize_entry_point_io_test.cc",
"../src/transform/decompose_memory_access_test.cc",
+ "../src/transform/decompose_strided_matrix_test.cc",
"../src/transform/external_texture_transform_test.cc",
"../src/transform/first_index_offset_test.cc",
"../src/transform/fold_constants_test.cc",
diff --git a/test/layout/storage/mat2x2/f32.wgsl b/test/layout/storage/mat2x2/f32.wgsl
new file mode 100644
index 0000000..c0cdb01
--- /dev/null
+++ b/test/layout/storage/mat2x2/f32.wgsl
@@ -0,0 +1,11 @@
+[[block]]
+struct SSBO {
+ m : mat2x2<f32>;
+};
+[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let v = ssbo.m;
+ ssbo.m = v;
+}
diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.hlsl b/test/layout/storage/mat2x2/f32.wgsl.expected.hlsl
new file mode 100644
index 0000000..49afc06
--- /dev/null
+++ b/test/layout/storage/mat2x2/f32.wgsl.expected.hlsl
@@ -0,0 +1,17 @@
+RWByteAddressBuffer ssbo : register(u0, space0);
+
+float2x2 tint_symbol(RWByteAddressBuffer buffer, uint offset) {
+ return float2x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u))));
+}
+
+void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, float2x2 value) {
+ buffer.Store2((offset + 0u), asuint(value[0u]));
+ buffer.Store2((offset + 8u), asuint(value[1u]));
+}
+
+[numthreads(1, 1, 1)]
+void f() {
+ const float2x2 v = tint_symbol(ssbo, 0u);
+ tint_symbol_2(ssbo, 0u, v);
+ return;
+}
diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.msl b/test/layout/storage/mat2x2/f32.wgsl.expected.msl
new file mode 100644
index 0000000..32de921
--- /dev/null
+++ b/test/layout/storage/mat2x2/f32.wgsl.expected.msl
@@ -0,0 +1,13 @@
+#include <metal_stdlib>
+
+using namespace metal;
+struct SSBO {
+ /* 0x0000 */ float2x2 m;
+};
+
+kernel void f(device SSBO& ssbo [[buffer(0)]]) {
+ float2x2 const v = ssbo.m;
+ ssbo.m = v;
+ return;
+}
+
diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.spvasm b/test/layout/storage/mat2x2/f32.wgsl.expected.spvasm
new file mode 100644
index 0000000..2e4ee55
--- /dev/null
+++ b/test/layout/storage/mat2x2/f32.wgsl.expected.spvasm
@@ -0,0 +1,38 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 17
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %f "f"
+ OpExecutionMode %f LocalSize 1 1 1
+ OpName %SSBO "SSBO"
+ OpMemberName %SSBO 0 "m"
+ OpName %ssbo "ssbo"
+ OpName %f "f"
+ OpDecorate %SSBO Block
+ OpMemberDecorate %SSBO 0 Offset 0
+ OpMemberDecorate %SSBO 0 ColMajor
+ OpMemberDecorate %SSBO 0 MatrixStride 8
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+%mat2v2float = OpTypeMatrix %v2float 2
+ %SSBO = OpTypeStruct %mat2v2float
+%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
+ %ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
+ %void = OpTypeVoid
+ %7 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float
+ %f = OpFunction %void None %7
+ %10 = OpLabel
+ %14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
+ %15 = OpLoad %mat2v2float %14
+ %16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
+ OpStore %16 %15
+ OpReturn
+ OpFunctionEnd
diff --git a/test/layout/storage/mat2x2/f32.wgsl.expected.wgsl b/test/layout/storage/mat2x2/f32.wgsl.expected.wgsl
new file mode 100644
index 0000000..b4f1f51
--- /dev/null
+++ b/test/layout/storage/mat2x2/f32.wgsl.expected.wgsl
@@ -0,0 +1,12 @@
+[[block]]
+struct SSBO {
+ m : mat2x2<f32>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+ let v = ssbo.m;
+ ssbo.m = v;
+}
diff --git a/test/layout/storage/mat2x2/stride/16.spvasm b/test/layout/storage/mat2x2/stride/16.spvasm
new file mode 100644
index 0000000..fbf1122
--- /dev/null
+++ b/test/layout/storage/mat2x2/stride/16.spvasm
@@ -0,0 +1,33 @@
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %f "f"
+ OpExecutionMode %f LocalSize 1 1 1
+ OpName %SSBO "SSBO"
+ OpMemberName %SSBO 0 "m"
+ OpName %ssbo "ssbo"
+ OpName %f "f"
+ OpDecorate %SSBO Block
+ OpMemberDecorate %SSBO 0 Offset 0
+ OpMemberDecorate %SSBO 0 ColMajor
+ OpMemberDecorate %SSBO 0 MatrixStride 16
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+ %mat2v2float = OpTypeMatrix %v2float 2
+ %SSBO = OpTypeStruct %mat2v2float
+ %_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
+ %ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
+ %void = OpTypeVoid
+ %7 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+%_ptr_StorageBuffer_mat2v2float = OpTypePointer StorageBuffer %mat2v2float
+ %f = OpFunction %void None %7
+ %10 = OpLabel
+ %14 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
+ %15 = OpLoad %mat2v2float %14
+ %16 = OpAccessChain %_ptr_StorageBuffer_mat2v2float %ssbo %uint_0
+ OpStore %16 %15
+ OpReturn
+ OpFunctionEnd
diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.hlsl b/test/layout/storage/mat2x2/stride/16.spvasm.expected.hlsl
new file mode 100644
index 0000000..f9ad3ad
--- /dev/null
+++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.hlsl
@@ -0,0 +1,47 @@
+struct tint_padded_array_element {
+ float2 el;
+};
+
+RWByteAddressBuffer ssbo : register(u0, space0);
+
+float2x2 arr_to_mat2x2_stride_16(tint_padded_array_element arr[2]) {
+ return float2x2(arr[0u].el, arr[1u].el);
+}
+
+typedef tint_padded_array_element mat2x2_stride_16_to_arr_ret[2];
+mat2x2_stride_16_to_arr_ret mat2x2_stride_16_to_arr(float2x2 mat) {
+ const tint_padded_array_element tint_symbol_4[2] = {{mat[0u]}, {mat[1u]}};
+ return tint_symbol_4;
+}
+
+typedef tint_padded_array_element tint_symbol_ret[2];
+tint_symbol_ret tint_symbol(RWByteAddressBuffer buffer, uint offset) {
+ tint_padded_array_element arr_1[2] = (tint_padded_array_element[2])0;
+ {
+ for(uint i = 0u; (i < 2u); i = (i + 1u)) {
+ arr_1[i].el = asfloat(buffer.Load2((offset + (i * 16u))));
+ }
+ }
+ return arr_1;
+}
+
+void tint_symbol_2(RWByteAddressBuffer buffer, uint offset, tint_padded_array_element value[2]) {
+ tint_padded_array_element array[2] = value;
+ {
+ for(uint i_1 = 0u; (i_1 < 2u); i_1 = (i_1 + 1u)) {
+ buffer.Store2((offset + (i_1 * 16u)), asuint(array[i_1].el));
+ }
+ }
+}
+
+void f_1() {
+ const float2x2 x_15 = arr_to_mat2x2_stride_16(tint_symbol(ssbo, 0u));
+ tint_symbol_2(ssbo, 0u, mat2x2_stride_16_to_arr(x_15));
+ return;
+}
+
+[numthreads(1, 1, 1)]
+void f() {
+ f_1();
+ return;
+}
diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.msl b/test/layout/storage/mat2x2/stride/16.spvasm.expected.msl
new file mode 100644
index 0000000..9bb899f
--- /dev/null
+++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.msl
@@ -0,0 +1,34 @@
+#include <metal_stdlib>
+
+using namespace metal;
+struct tint_padded_array_element {
+ /* 0x0000 */ packed_float2 el;
+ /* 0x0008 */ int8_t tint_pad[8];
+};
+struct tint_array_wrapper {
+ /* 0x0000 */ tint_padded_array_element arr[2];
+};
+struct SSBO {
+ /* 0x0000 */ tint_array_wrapper m;
+};
+
+float2x2 arr_to_mat2x2_stride_16(tint_array_wrapper arr) {
+ return float2x2(arr.arr[0u].el, arr.arr[1u].el);
+}
+
+tint_array_wrapper mat2x2_stride_16_to_arr(float2x2 mat) {
+ tint_array_wrapper const tint_symbol = {.arr={{.el=mat[0u]}, {.el=mat[1u]}}};
+ return tint_symbol;
+}
+
+void f_1(device SSBO& ssbo) {
+ float2x2 const x_15 = arr_to_mat2x2_stride_16(ssbo.m);
+ ssbo.m = mat2x2_stride_16_to_arr(x_15);
+ return;
+}
+
+kernel void f(device SSBO& ssbo [[buffer(0)]]) {
+ f_1(ssbo);
+ return;
+}
+
diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.spvasm b/test/layout/storage/mat2x2/stride/16.spvasm.expected.spvasm
new file mode 100644
index 0000000..94c280b
--- /dev/null
+++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.spvasm
@@ -0,0 +1,70 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 39
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %f "f"
+ OpExecutionMode %f LocalSize 1 1 1
+ OpName %SSBO "SSBO"
+ OpMemberName %SSBO 0 "m"
+ OpName %ssbo "ssbo"
+ OpName %arr_to_mat2x2_stride_16 "arr_to_mat2x2_stride_16"
+ OpName %arr "arr"
+ OpName %mat2x2_stride_16_to_arr "mat2x2_stride_16_to_arr"
+ OpName %mat "mat"
+ OpName %f_1 "f_1"
+ OpName %f "f"
+ OpDecorate %SSBO Block
+ OpMemberDecorate %SSBO 0 Offset 0
+ OpDecorate %_arr_v2float_uint_2 ArrayStride 16
+ OpDecorate %ssbo DescriptorSet 0
+ OpDecorate %ssbo Binding 0
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+%_arr_v2float_uint_2 = OpTypeArray %v2float %uint_2
+ %SSBO = OpTypeStruct %_arr_v2float_uint_2
+%_ptr_StorageBuffer_SSBO = OpTypePointer StorageBuffer %SSBO
+ %ssbo = OpVariable %_ptr_StorageBuffer_SSBO StorageBuffer
+%mat2v2float = OpTypeMatrix %v2float 2
+ %9 = OpTypeFunction %mat2v2float %_arr_v2float_uint_2
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %19 = OpTypeFunction %_arr_v2float_uint_2 %mat2v2float
+ %void = OpTypeVoid
+ %26 = OpTypeFunction %void
+%_ptr_StorageBuffer__arr_v2float_uint_2 = OpTypePointer StorageBuffer %_arr_v2float_uint_2
+%arr_to_mat2x2_stride_16 = OpFunction %mat2v2float None %9
+ %arr = OpFunctionParameter %_arr_v2float_uint_2
+ %13 = OpLabel
+ %15 = OpCompositeExtract %v2float %arr 0
+ %17 = OpCompositeExtract %v2float %arr 1
+ %18 = OpCompositeConstruct %mat2v2float %15 %17
+ OpReturnValue %18
+ OpFunctionEnd
+%mat2x2_stride_16_to_arr = OpFunction %_arr_v2float_uint_2 None %19
+ %mat = OpFunctionParameter %mat2v2float
+ %22 = OpLabel
+ %23 = OpCompositeExtract %v2float %mat 0
+ %24 = OpCompositeExtract %v2float %mat 1
+ %25 = OpCompositeConstruct %_arr_v2float_uint_2 %23 %24
+ OpReturnValue %25
+ OpFunctionEnd
+ %f_1 = OpFunction %void None %26
+ %29 = OpLabel
+ %32 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0
+ %33 = OpLoad %_arr_v2float_uint_2 %32
+ %30 = OpFunctionCall %mat2v2float %arr_to_mat2x2_stride_16 %33
+ %34 = OpAccessChain %_ptr_StorageBuffer__arr_v2float_uint_2 %ssbo %uint_0
+ %35 = OpFunctionCall %_arr_v2float_uint_2 %mat2x2_stride_16_to_arr %30
+ OpStore %34 %35
+ OpReturn
+ OpFunctionEnd
+ %f = OpFunction %void None %26
+ %37 = OpLabel
+ %38 = OpFunctionCall %void %f_1
+ OpReturn
+ OpFunctionEnd
diff --git a/test/layout/storage/mat2x2/stride/16.spvasm.expected.wgsl b/test/layout/storage/mat2x2/stride/16.spvasm.expected.wgsl
new file mode 100644
index 0000000..0ac837a
--- /dev/null
+++ b/test/layout/storage/mat2x2/stride/16.spvasm.expected.wgsl
@@ -0,0 +1,25 @@
+[[block]]
+struct SSBO {
+ m : [[stride(16)]] array<vec2<f32>, 2>;
+};
+
+[[group(0), binding(0)]] var<storage, read_write> ssbo : SSBO;
+
+fn arr_to_mat2x2_stride_16(arr : [[stride(16)]] array<vec2<f32>, 2>) -> mat2x2<f32> {
+ return mat2x2<f32>(arr[0u], arr[1u]);
+}
+
+fn mat2x2_stride_16_to_arr(mat : mat2x2<f32>) -> [[stride(16)]] array<vec2<f32>, 2> {
+ return [[stride(16)]] array<vec2<f32>, 2>(mat[0u], mat[1u]);
+}
+
+fn f_1() {
+ let x_15 : mat2x2<f32> = arr_to_mat2x2_stride_16(ssbo.m);
+ ssbo.m = mat2x2_stride_16_to_arr(x_15);
+ return;
+}
+
+[[stage(compute), workgroup_size(1, 1, 1)]]
+fn f() {
+ f_1();
+}