[spirv-reader] Add TransposeRowMajor transform
Replace structure members that have @row_major attributes with the
transposed equivalents, and then update all relevant accessor
expressions to either transpose the whole matrix or swap accessor
indices as necessary.
Bug: 364267168
Change-Id: I3f25276a6ec96b68bf0bab05ee79b482f3a978e1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/207418
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel b/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
index 7bb61d7..4c3560d 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
@@ -44,6 +44,7 @@
"decompose_strided_matrix.cc",
"fold_trivial_lets.cc",
"pass_workgroup_id_as_argument.cc",
+ "transpose_row_major.cc",
],
hdrs = [
"atomics.h",
@@ -51,6 +52,7 @@
"decompose_strided_matrix.h",
"fold_trivial_lets.h",
"pass_workgroup_id_as_argument.h",
+ "transpose_row_major.h",
],
deps = [
"//src/tint/api/common",
@@ -92,6 +94,7 @@
"decompose_strided_matrix_test.cc",
"fold_trivial_lets_test.cc",
"pass_workgroup_id_as_argument_test.cc",
+ "transpose_row_major_test.cc",
],
deps = [
"//src/tint/api/common",
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake b/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
index 361645f..5a87287 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
@@ -51,6 +51,8 @@
lang/spirv/reader/ast_lower/fold_trivial_lets.h
lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h
+ lang/spirv/reader/ast_lower/transpose_row_major.cc
+ lang/spirv/reader/ast_lower/transpose_row_major.h
)
tint_target_add_dependencies(tint_lang_spirv_reader_ast_lower lib
@@ -98,6 +100,7 @@
lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
lang/spirv/reader/ast_lower/fold_trivial_lets_test.cc
lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
+ lang/spirv/reader/ast_lower/transpose_row_major_test.cc
)
tint_target_add_dependencies(tint_lang_spirv_reader_ast_lower_test test
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.gn b/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
index 4acb4a2..6252459 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
@@ -55,6 +55,8 @@
"fold_trivial_lets.h",
"pass_workgroup_id_as_argument.cc",
"pass_workgroup_id_as_argument.h",
+ "transpose_row_major.cc",
+ "transpose_row_major.h",
]
deps = [
"${dawn_root}/src/utils:utils",
@@ -96,6 +98,7 @@
"decompose_strided_matrix_test.cc",
"fold_trivial_lets_test.cc",
"pass_workgroup_id_as_argument_test.cc",
+ "transpose_row_major_test.cc",
]
deps = [
"${dawn_root}/src/utils:utils",
diff --git a/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.cc b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.cc
new file mode 100644
index 0000000..f77c7c1
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.cc
@@ -0,0 +1,344 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+#include "src/tint/lang/wgsl/sem/call.h"
+#include "src/tint/lang/wgsl/sem/index_accessor_expression.h"
+#include "src/tint/lang/wgsl/sem/load.h"
+#include "src/tint/lang/wgsl/sem/member_accessor_expression.h"
+#include "src/tint/utils/rtti/switch.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::TransposeRowMajor);
+
+namespace tint::spirv::reader {
+
+TransposeRowMajor::TransposeRowMajor() = default;
+
+TransposeRowMajor::~TransposeRowMajor() = default;
+
+/// PIMPL state for the transform.
+struct TransposeRowMajor::State {
+ /// The source program
+ const Program& src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};
+ /// The semantic info.
+ const sem::Info& sem = src.Sem();
+
+ /// Map from matrix reference to column load helper function.
+ Hashmap<const core::type::Type*, Symbol, 4> column_load_helpers;
+
+ /// Map from matrix reference to column store helper function.
+ Hashmap<const core::type::Type*, Symbol, 4> column_store_helpers;
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program& program) : src(program) {}
+
+ ApplyResult Run() {
+ // Scan the program for all storage and uniform structure matrix members with the @row_major
+ // attribute. Replace these matrices with a transposed version, and populate the
+ // `transposed_members` map with the members that have been replaced.
+ Hashset<const core::type::StructMember*, 8> transposed_members;
+ for (auto* decl : src.AST().TypeDecls()) {
+ if (auto* str = decl->As<ast::Struct>()) {
+ auto* str_ty = src.Sem().Get(str);
+ if (!str_ty->UsedAs(core::AddressSpace::kUniform) &&
+ !str_ty->UsedAs(core::AddressSpace::kStorage)) {
+ continue;
+ }
+ for (auto* member : str_ty->Members()) {
+ auto* matrix = member->Type()->As<core::type::Matrix>();
+ if (!matrix) {
+ continue;
+ }
+ auto* attr = ast::GetAttribute<ast::RowMajorAttribute>(
+ member->Declaration()->attributes);
+ if (!attr) {
+ continue;
+ }
+ // We've got a struct member of a matrix type with a row-major memory layout.
+ // Transpose it, remove the @row_major attribute, and record it in the set.
+ auto transposed_matrix = b.ty.mat(CreateASTTypeFor(ctx, matrix->Type()),
+ matrix->Rows(), matrix->Columns());
+ ctx.Remove(member->Declaration()->attributes, attr);
+ auto* replacement = b.Member(ctx.Clone(member->Name()), transposed_matrix,
+ ctx.Clone(member->Declaration()->attributes));
+ ctx.Replace(member->Declaration(), replacement);
+ transposed_members.Add(member);
+ }
+ }
+ }
+
+ if (transposed_members.IsEmpty()) {
+ return SkipTransform;
+ }
+
+ // Look for expressions that access the matrix.
+ // The `row_major_accesses` map tracks expressions that are accessing a transposed matrix
+ // (or a subset of it), and the originating struct member extraction they came from.
+ Hashmap<const sem::ValueExpression*, const sem::StructMemberAccess*, 8> row_major_accesses;
+ for (auto* node : src.ASTNodes().Objects()) {
+ // Check for assignments to all or part of a transposed matrix and replace them.
+ if (auto* assign = node->As<ast::AssignmentStatement>()) {
+ auto* lhs = src.Sem().GetVal(assign->lhs);
+ auto row_major_access = row_major_accesses.Get(lhs);
+ if (row_major_access) {
+ ReplaceAssignment(assign);
+ row_major_accesses.Remove(lhs);
+ }
+ }
+
+ auto* sem_expr = sem.GetVal(node);
+ if (!sem_expr) {
+ continue;
+ }
+
+ if (auto* accessor = sem_expr->UnwrapLoad()->As<sem::AccessorExpression>()) {
+ if (auto* member_access = accessor->As<sem::StructMemberAccess>()) {
+ // Check if we are accessing a struct member that is a transposed matrix.
+ if (transposed_members.Contains(member_access->Member())) {
+ if (member_access->Type()->Is<core::type::MemoryView>()) {
+ // This is a pointer, so track the access until we hit a load or store.
+ row_major_accesses.Add(member_access, member_access);
+ } else {
+ // This is not a pointer, so we are extracting a matrix from a value
+ // type. Transpose the matrix now so that all child expressions behave
+ // as expected.
+ ctx.Replace(
+ member_access->Declaration(),
+ b.Call("transpose", ctx.Clone(member_access->Declaration())));
+ }
+ }
+ } else {
+ // For non-struct-member accesses, check if the base object is a transposed
+ // matrix and track the resulting sub-expression if so.
+ auto row_major_access = row_major_accesses.Get(accessor->Object());
+ if (row_major_access) {
+ row_major_accesses.Add(accessor, *row_major_access.value);
+ row_major_accesses.Remove(accessor->Object());
+ }
+ }
+ }
+
+ // Check for loads from all or part of a transposed matrix and replace them.
+ if (auto* load = sem_expr->As<sem::Load>()) {
+ auto row_major_access = row_major_accesses.Get(load->Source());
+ if (row_major_access) {
+ ReplaceLoad(load);
+ row_major_accesses.Remove(load->Source());
+ }
+ }
+
+ // Check for constructors of structures that contain transposed matrices, and transpose
+ // the relevant arguments.
+ if (auto* call = sem_expr->As<sem::Call>()) {
+ if (auto* str = call->Type()->As<core::type::Struct>()) {
+ for (uint32_t i = 0; i < str->Members().Length(); i++) {
+ if (transposed_members.Contains(str->Members()[i])) {
+ auto* arg = call->Arguments()[i]->Declaration();
+ ctx.Replace(arg, b.Call("transpose", ctx.Clone(arg)));
+ }
+ }
+ }
+ }
+ }
+
+ ctx.Clone();
+ return resolver::Resolve(b);
+ }
+
+ /// Replace an assignment to a transposed matrix.
+ /// @param assign the assignment statement to replace
+ void ReplaceAssignment(const ast::AssignmentStatement* assign) {
+ auto* lhs = src.Sem().GetVal(assign->lhs);
+ Switch(
+ src.Sem().GetVal(assign->rhs)->Type(),
+ [&](const core::type::Matrix*) {
+ // We are storing the whole matrix, so just transpose the RHS.
+ ctx.Replace(assign->rhs, b.Call("transpose", ctx.Clone(assign->rhs)));
+ },
+ [&](const core::type::Vector*) {
+ // We are storing a single column, which has to be done element-wise.
+ // Call a helper function to do this.
+ auto* col_access = lhs->As<sem::IndexAccessorExpression>();
+ TINT_ASSERT(col_access);
+ auto* to = b.AddressOf(ctx.Clone(col_access->Object()->Declaration()));
+ auto* idx = b.Call("u32", ctx.Clone(col_access->Index()->Declaration()));
+ auto* col = ctx.Clone(assign->rhs);
+ ctx.Replace(assign,
+ b.CallStmt(b.Call(StoreColumnHelper(col_access->Object()->Type()), to,
+ idx, col)));
+ },
+ [&](const core::type::Scalar*) {
+ // We are storing a single element, so reconstruct the index accessors with the
+ // column and row indices swapped over.
+ ctx.Replace(assign->lhs, TransposeAccessIndices(lhs));
+ },
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ /// Get (or create) a helper function that will assign a column to a transposed matrix.
+ /// @param dest_type the matrix type we are storing to
+ /// @returns the name of the helper function
+ Symbol StoreColumnHelper(const core::type::Type* dest_type) {
+ auto* ref_type = dest_type->As<core::type::Reference>();
+ TINT_ASSERT(ref_type);
+ auto* matrix_type = ref_type->StoreType()->As<core::type::Matrix>();
+ TINT_ASSERT(matrix_type);
+ return column_store_helpers.GetOrAdd(ref_type, [&] {
+ // The helper function will look like this:
+ // fn tint_store_row_major(to: ptr<private, mat3x2<f32>>, idx : u32, col: vec3<f32>) {
+ // to[0][idx] = col[0];
+ // to[1][idx] = col[1];
+ // to[2][idx] = col[2];
+ // }
+ auto name = b.Symbols().New("tint_store_row_major_column");
+ auto transposed = b.ty.mat(CreateASTTypeFor(ctx, matrix_type->Type()),
+ matrix_type->Rows(), matrix_type->Columns());
+ auto ptr = b.ty.ptr(ref_type->AddressSpace(), transposed,
+ ref_type->AddressSpace() == core::AddressSpace::kStorage
+ ? ref_type->Access()
+ : core::Access::kUndefined);
+ auto* to = b.Param("tint_to", ptr);
+ auto* idx = b.Param("tint_idx", b.ty.u32());
+ auto* col = b.Param("tint_col", CreateASTTypeFor(ctx, matrix_type->ColumnType()));
+ Vector<const ast::Statement*, 4> body;
+ for (uint32_t i = 0; i < matrix_type->Rows(); i++) {
+ body.Push(b.Assign(b.IndexAccessor(b.IndexAccessor(to, b.Expr(AInt(i))), idx),
+ b.IndexAccessor(col, b.Expr(AInt(i)))));
+ }
+ b.Func(name, Vector{to, idx, col}, {}, std::move(body));
+ return name;
+ });
+ }
+
+ /// Replace a load from a transposed matrix.
+ /// @param load the load expression to replace
+ void ReplaceLoad(const sem::Load* load) {
+ Switch(
+ load->Type(),
+ [&](const core::type::Matrix*) {
+ // We are loading the whole matrix, so just transpose the result.
+ ctx.Replace(load->Declaration(),
+ b.Call("transpose", ctx.Clone(load->Declaration())));
+ },
+ [&](const core::type::Vector*) {
+ // We are loading a single column, which has to be done element-wise.
+ // Call a helper function to do this.
+ auto* col_access = load->Source()->As<sem::IndexAccessorExpression>();
+ TINT_ASSERT(col_access);
+ auto* from = b.AddressOf(ctx.Clone(col_access->Object()->Declaration()));
+ auto* idx = b.Call("u32", ctx.Clone(col_access->Index()->Declaration()));
+ ctx.Replace(load->Declaration(),
+ b.Call(LoadColumnHelper(col_access->Object()->Type()), from, idx));
+ },
+ [&](const core::type::Scalar*) {
+ // We are loading a single element, so reconstruct the index accessors with the
+ // column and row indices swapped over.
+ ctx.Replace(load->Declaration(), TransposeAccessIndices(load->Source()));
+ },
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ /// Get (or create) a helper function that will load a column from a transposed matrix.
+ /// @param src_type the matrix type we are load from
+ /// @returns the name of the helper function
+ Symbol LoadColumnHelper(const core::type::Type* src_type) {
+ auto* ref_type = src_type->As<core::type::Reference>();
+ TINT_ASSERT(ref_type);
+ auto* matrix_type = ref_type->StoreType()->As<core::type::Matrix>();
+ TINT_ASSERT(matrix_type);
+ return column_load_helpers.GetOrAdd(ref_type, [&] {
+ // The helper function will look like this:
+ // fn tint_load_row_major(from: ptr<private, mat3x2<f32>>, idx : u32) -> vec3<f32> {
+ // return vec3<f32>(to[0][idx], to[1][idx], to[2][idx]);
+ // }
+ auto name = b.Symbols().New("tint_load_row_major_column");
+ auto transposed = b.ty.mat(CreateASTTypeFor(ctx, matrix_type->Type()),
+ matrix_type->Rows(), matrix_type->Columns());
+ auto ptr = b.ty.ptr(ref_type->AddressSpace(), transposed,
+ ref_type->AddressSpace() == core::AddressSpace::kStorage
+ ? ref_type->Access()
+ : core::Access::kUndefined);
+ auto* from = b.Param("tint_from", ptr);
+ auto* idx = b.Param("tint_idx", b.ty.u32());
+ Vector<const ast::Expression*, 4> rows;
+ for (uint32_t i = 0; i < matrix_type->Rows(); i++) {
+ rows.Push(b.IndexAccessor(b.IndexAccessor(from, b.Expr(AInt(i))), idx));
+ }
+ b.Func(name, Vector{from, idx}, CreateASTTypeFor(ctx, matrix_type->ColumnType()),
+ Vector{
+ b.Return(b.Call(CreateASTTypeFor(ctx, matrix_type->ColumnType()),
+ std::move(rows))),
+ });
+ return name;
+ });
+ }
+
+ /// Swap the column and row indices for a matrix element accessor chain.
+ /// @param expr the accessor expression to transpose
+ /// @returns the transposed access chain
+ const ast::Expression* TransposeAccessIndices(const sem::ValueExpression* expr) {
+ auto* row_access = expr->As<sem::AccessorExpression>();
+ TINT_ASSERT(row_access);
+ auto* col_access = row_access->Object()->As<sem::IndexAccessorExpression>();
+ TINT_ASSERT(col_access);
+ auto* matrix = ctx.Clone(col_access->Object()->Declaration());
+ auto* col_idx = ctx.Clone(col_access->Index()->Declaration());
+
+ // The row index could either be a array accessor or a vector component swizzle.
+ const ast::Expression* row_idx = nullptr;
+ if (auto* index = row_access->As<sem::IndexAccessorExpression>()) {
+ row_idx = ctx.Clone(index->Index()->Declaration());
+ } else if (auto* swizzle = row_access->As<sem::Swizzle>()) {
+ row_idx = b.Expr(u32(swizzle->Indices()[0]));
+ } else {
+ TINT_UNREACHABLE();
+ }
+
+ return b.IndexAccessor(b.IndexAccessor(matrix, row_idx), col_idx);
+ }
+};
+
+ast::transform::Transform::ApplyResult TransposeRowMajor::Apply(const Program& src,
+ const ast::transform::DataMap&,
+ ast::transform::DataMap&) const {
+ return State(src).Run();
+}
+
+} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h
new file mode 100644
index 0000000..95a1bfd
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h
@@ -0,0 +1,61 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_SPIRV_READER_AST_LOWER_TRANSPOSE_ROW_MAJOR_H_
+#define SRC_TINT_LANG_SPIRV_READER_AST_LOWER_TRANSPOSE_ROW_MAJOR_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::spirv::reader {
+
+/// TransposeRowMajor replaces matrix members of storage or uniform buffer structures that have a
+/// @row_major attribute,
+///
+/// This transform is used by the SPIR-V reader to handle the SPIR-V RowMajor attribute.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class TransposeRowMajor final : public Castable<TransposeRowMajor, ast::transform::Transform> {
+ public:
+ /// Constructor
+ TransposeRowMajor();
+
+ /// Destructor
+ ~TransposeRowMajor() override;
+
+ /// @copydoc ast::transform::Transform::Apply
+ ApplyResult Apply(const Program& program,
+ const ast::transform::DataMap& inputs,
+ ast::transform::DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::spirv::reader
+
+#endif // SRC_TINT_LANG_SPIRV_READER_AST_LOWER_TRANSPOSE_ROW_MAJOR_H_
diff --git a/src/tint/lang/spirv/reader/ast_lower/transpose_row_major_test.cc b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major_test.cc
new file mode 100644
index 0000000..e1e26a9
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major_test.cc
@@ -0,0 +1,985 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/helper_test.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+
+namespace tint::spirv::reader {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+using TransposeRowMajorTest = ast::transform::TransformTest;
+using SimplifyPointers = ast::transform::SimplifyPointers;
+
+TEST_F(TransposeRowMajorTest, ShouldRunEmptyModule) {
+ auto* src = R"()";
+
+ EXPECT_FALSE(ShouldRun<TransposeRowMajor>(src));
+}
+
+TEST_F(TransposeRowMajorTest, ShouldRunColumnMajorMatrix) {
+ auto* src = R"(
+struct S {
+ m : mat3x2<f32>
+}
+
+@group(0) @binding(0)
+var<uniform> s : S;
+)";
+
+ EXPECT_FALSE(ShouldRun<TransposeRowMajor>(src));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformMatrix) {
+ // struct S {
+ // @offset(16)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x3<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.mat2x3<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x3<f32> = transpose(s.m);
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformColumn) {
+ // struct S {
+ // @offset(16)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : vec3<f32> = s.m[1];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.vec3<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+fn tint_load_row_major_column(tint_from : ptr<uniform, mat3x2<f32>>, tint_idx : u32) -> vec3<f32> {
+ return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : vec3<f32> = tint_load_row_major_column(&(s.m), u32(1i));
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformElement_MemberAccessor) {
+ // struct S {
+ // @offset(16)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // let x : f32 = s.m[col_idx].z;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Decl(b.Let(
+ "x", b.ty.f32(),
+ b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"), "z"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ let x : f32 = s.m[2u][col_idx];
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformElement_IndexAccessor) {
+ // struct S {
+ // @offset(16)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // let row_idx : i32 = 2i;
+ // let x : f32 = s.m[col_idx][row_idx];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Decl(b.Let("row_idx", b.ty.i32(), b.Expr(2_i))),
+ b.Decl(b.Let("x", b.ty.f32(),
+ b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"),
+ "row_idx"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ let row_idx : i32 = 2i;
+ let x : f32 = s.m[row_idx][col_idx];
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformSwizzle) {
+ // struct S {
+ // @offset(16)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // let x : vec2<f32> = s.m[1].zx;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Decl(b.Let(
+ "x", b.ty.vec2<f32>(),
+ b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"), "zx"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+fn tint_load_row_major_column(tint_from : ptr<uniform, mat3x2<f32>>, tint_idx : u32) -> vec3<f32> {
+ return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ let x : vec2<f32> = tint_load_row_major_column(&(s.m), u32(col_idx)).zx;
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageMatrix) {
+ // struct S {
+ // @offset(8)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // s.m = mat2x3<f32>(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(8_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Assign(b.MemberAccessor("s", "m"), b.Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f)),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding_0 : u32,
+ /* @offset(8u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ s.m = transpose(mat2x3<f32>(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f));
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageColumn) {
+ // struct S {
+ // @offset(8)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // s.m[1] = vec3<f32>(1.0, 2.0, 3.0);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(8_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"),
+ b.Call<vec3<f32>>(1_f, 2_f, 3_f)),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+fn tint_store_row_major_column(tint_to : ptr<storage, mat3x2<f32>, read_write>, tint_idx : u32, tint_col : vec3<f32>) {
+ tint_to[0][tint_idx] = tint_col[0];
+ tint_to[1][tint_idx] = tint_col[1];
+ tint_to[2][tint_idx] = tint_col[2];
+}
+
+struct S {
+ @size(8)
+ padding_0 : u32,
+ /* @offset(8u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ tint_store_row_major_column(&(s.m), u32(col_idx), vec3<f32>(1.0f, 2.0f, 3.0f));
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageElement_MemberAccessor) {
+ // struct S {
+ // @offset(8)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // s.m[1].z = 1.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(8_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Assign(b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"), "z"),
+ 1_f),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding_0 : u32,
+ /* @offset(8u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ s.m[2u][col_idx] = 1.0f;
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageElement_IndexAccessor) {
+ // struct S {
+ // @offset(8)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // let row_idx : i32 = 2i;
+ // s.m[1][idx] = 1.0;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(8_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Decl(b.Let("row_idx", b.ty.i32(), b.Expr(2_i))),
+ b.Assign(b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"),
+ "row_idx"),
+ 1_f),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(8)
+ padding_0 : u32,
+ /* @offset(8u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ let row_idx : i32 = 2i;
+ s.m[row_idx][col_idx] = 1.0f;
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ExtractFromLoadedStruct) {
+ // struct S {
+ // @offset(16)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let col_idx : i32 = 1i;
+ // let row_idx : i32 = 2i;
+ // let load = s;
+ // let m : mat2x3<f32> = load.m;
+ // let c : vec3<f32> = load.m[col_idx];
+ // let e : vec3<f32> = load.m[col_idx][row_idx];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(16_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func(
+ "f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+ b.Decl(b.Let("row_idx", b.ty.i32(), b.Expr(2_i))),
+ b.Decl(b.Let("load", b.ty.Of(S), b.Expr("s"))),
+ b.Decl(b.Let("m", b.ty.mat2x3<f32>(), b.MemberAccessor("load", "m"))),
+ b.Decl(b.Let("c", b.ty.vec3<f32>(),
+ b.IndexAccessor(b.MemberAccessor("load", "m"), "col_idx"))),
+ b.Decl(b.Let("e", b.ty.f32(),
+ b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("load", "m"), "col_idx"),
+ "row_idx"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ @size(16)
+ padding_0 : u32,
+ /* @offset(16u) */
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let col_idx : i32 = 1i;
+ let row_idx : i32 = 2i;
+ let load : S = s;
+ let m : mat2x3<f32> = transpose(load.m);
+ let c : vec3<f32> = transpose(load.m)[col_idx];
+ let e : f32 = transpose(load.m)[col_idx][row_idx];
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, InsertInStructConstructor) {
+ // struct S {
+ // @offset(0) @row_major m1 : mat2x3<f32>,
+ // @offset(32) m2 : mat4x2<f32>,
+ // @offset(64) @row_major m3 : mat4x2<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let m1 = mat2x3<f32>();
+ // let m2 = mat4x2<f32>();
+ // s = S(m, m2, m2);
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(0_u),
+ b.RowMajor(),
+ }),
+ b.Member("m1", b.ty.mat4x2<f32>(),
+ Vector{
+ b.MemberOffset(32_u),
+ }),
+ b.Member("m2", b.ty.mat4x2<f32>(),
+ Vector{
+ b.MemberOffset(64_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("m1", b.Call("mat2x3f"))),
+ b.Decl(b.Let("m2", b.Call("mat4x2f"))),
+ b.Assign("s", b.Call("S", "m1", "m2", "m2")),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ /* @offset(0u) */
+ m : mat3x2<f32>,
+ @size(8)
+ padding_0 : u32,
+ /* @offset(32u) */
+ m1 : mat4x2<f32>,
+ /* @offset(64u) */
+ m2 : mat2x4<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let m1 = mat2x3f();
+ let m2 = mat4x2f();
+ s = S(transpose(m1), m2, transpose(m2));
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, DeeplyNested) {
+ // struct Inner {
+ // @offset(0)
+ // @row_major
+ // m : mat4x3<f32>,
+ // };
+ // struct Outer {
+ // @offset(0)
+ // arr : array<Inner, 4>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> buffer : Outer;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let m = buffer.arr[1].m;
+ // buffer.arr[0].m[3] = m[2];
+ // }
+ ProgramBuilder b;
+ auto* inner = b.Structure("Inner", Vector{
+ b.Member("m", b.ty.mat4x3<f32>(),
+ Vector{
+ b.MemberOffset(0_u),
+ b.RowMajor(),
+ }),
+ });
+ auto* outer = b.Structure("Outer", Vector{
+ b.Member("arr", b.ty.array(b.ty.Of(inner), 4_a),
+ Vector{
+ b.MemberOffset(0_u),
+ }),
+ });
+ b.GlobalVar("buffer", b.ty.Of(outer), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let(
+ "m", b.ty.mat4x3<f32>(),
+ b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("buffer", "arr"), 1_a), "m"))),
+ b.Assign(b.IndexAccessor(
+ b.MemberAccessor(
+ b.IndexAccessor(b.MemberAccessor("buffer", "arr"), 0_a), "m"),
+ 3_a),
+ b.IndexAccessor("m", 2_a)),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+fn tint_store_row_major_column(tint_to : ptr<storage, mat3x4<f32>, read_write>, tint_idx : u32, tint_col : vec3<f32>) {
+ tint_to[0][tint_idx] = tint_col[0];
+ tint_to[1][tint_idx] = tint_col[1];
+ tint_to[2][tint_idx] = tint_col[2];
+}
+
+struct Inner {
+ /* @offset(0u) */
+ m : mat3x4<f32>,
+}
+
+struct Outer {
+ /* @offset(0u) */
+ arr : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : Outer;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let m : mat4x3<f32> = transpose(buffer.arr[1].m);
+ tint_store_row_major_column(&(buffer.arr[0].m), u32(3), m[2]);
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, MultipleColumnHelpers) {
+ // struct S {
+ // @offset(0) @row_major m1 : mat2x3<f32>,
+ // @offset(32) @row_major m2 : mat4x2<f32>,
+ // };
+ // @group(0) @binding(0) var<storage, read_write> s : S;
+ // var<private> ps : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // ps.m1[0] = s.m1[1];
+ // ps.m1[1] = s.m1[0];
+ // ps.m2[2] = s.m1[3];
+ // ps.m2[3] = s.m1[2];
+ //
+ // s.m1[0] = ps.m1[0];
+ // s.m1[1] = ps.m1[1];
+ // s.m2[2] = ps.m2[2];
+ // s.m2[3] = ps.m2[3];
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure("S", Vector{
+ b.Member("m1", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(0_u),
+ b.RowMajor(),
+ }),
+ b.Member("m2", b.ty.mat4x2<f32>(),
+ Vector{
+ b.MemberOffset(32_u),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+ b.Group(0_a), b.Binding(0_a));
+ b.GlobalVar("ps", b.ty.Of(S), core::AddressSpace::kPrivate);
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m1"), 0_u),
+ b.IndexAccessor(b.MemberAccessor("s", "m1"), 1_u)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m1"), 1_u),
+ b.IndexAccessor(b.MemberAccessor("s", "m1"), 0_u)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m2"), 2_u),
+ b.IndexAccessor(b.MemberAccessor("s", "m2"), 3_u)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m2"), 3_u),
+ b.IndexAccessor(b.MemberAccessor("s", "m2"), 2_u)),
+
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m1"), 0_u),
+ b.IndexAccessor(b.MemberAccessor("ps", "m1"), 0_u)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m1"), 1_u),
+ b.IndexAccessor(b.MemberAccessor("ps", "m1"), 1_u)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m2"), 2_u),
+ b.IndexAccessor(b.MemberAccessor("ps", "m2"), 2_u)),
+ b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m2"), 3_u),
+ b.IndexAccessor(b.MemberAccessor("ps", "m2"), 3_u)),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+fn tint_load_row_major_column(tint_from : ptr<storage, mat3x2<f32>, read_write>, tint_idx : u32) -> vec3<f32> {
+ return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+fn tint_store_row_major_column(tint_to : ptr<private, mat3x2<f32>>, tint_idx : u32, tint_col : vec3<f32>) {
+ tint_to[0][tint_idx] = tint_col[0];
+ tint_to[1][tint_idx] = tint_col[1];
+ tint_to[2][tint_idx] = tint_col[2];
+}
+
+fn tint_load_row_major_column_1(tint_from : ptr<storage, mat2x4<f32>, read_write>, tint_idx : u32) -> vec2<f32> {
+ return vec2<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx]);
+}
+
+fn tint_store_row_major_column_1(tint_to : ptr<private, mat2x4<f32>>, tint_idx : u32, tint_col : vec2<f32>) {
+ tint_to[0][tint_idx] = tint_col[0];
+ tint_to[1][tint_idx] = tint_col[1];
+}
+
+fn tint_load_row_major_column_2(tint_from : ptr<private, mat3x2<f32>>, tint_idx : u32) -> vec3<f32> {
+ return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+fn tint_store_row_major_column_2(tint_to : ptr<storage, mat3x2<f32>, read_write>, tint_idx : u32, tint_col : vec3<f32>) {
+ tint_to[0][tint_idx] = tint_col[0];
+ tint_to[1][tint_idx] = tint_col[1];
+ tint_to[2][tint_idx] = tint_col[2];
+}
+
+fn tint_load_row_major_column_3(tint_from : ptr<private, mat2x4<f32>>, tint_idx : u32) -> vec2<f32> {
+ return vec2<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx]);
+}
+
+fn tint_store_row_major_column_3(tint_to : ptr<storage, mat2x4<f32>, read_write>, tint_idx : u32, tint_col : vec2<f32>) {
+ tint_to[0][tint_idx] = tint_col[0];
+ tint_to[1][tint_idx] = tint_col[1];
+}
+
+struct S {
+ /* @offset(0u) */
+ m1 : mat3x2<f32>,
+ /* @offset(32u) */
+ m2 : mat2x4<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> ps : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ tint_store_row_major_column(&(ps.m1), u32(0u), tint_load_row_major_column(&(s.m1), u32(1u)));
+ tint_store_row_major_column(&(ps.m1), u32(1u), tint_load_row_major_column(&(s.m1), u32(0u)));
+ tint_store_row_major_column_1(&(ps.m2), u32(2u), tint_load_row_major_column_1(&(s.m2), u32(3u)));
+ tint_store_row_major_column_1(&(ps.m2), u32(3u), tint_load_row_major_column_1(&(s.m2), u32(2u)));
+ tint_store_row_major_column_2(&(s.m1), u32(0u), tint_load_row_major_column_2(&(ps.m1), u32(0u)));
+ tint_store_row_major_column_2(&(s.m1), u32(1u), tint_load_row_major_column_2(&(ps.m1), u32(1u)));
+ tint_store_row_major_column_3(&(s.m2), u32(2u), tint_load_row_major_column_3(&(ps.m2), u32(2u)));
+ tint_store_row_major_column_3(&(s.m2), u32(3u), tint_load_row_major_column_3(&(ps.m2), u32(3u)));
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, PreserveMatrixStride) {
+ // struct S {
+ // @offset(0)
+ // @stride(32)
+ // @row_major
+ // m : mat2x3<f32>,
+ // };
+ // @group(0) @binding(0) var<uniform> s : S;
+ //
+ // @compute @workgroup_size(1)
+ // fn f() {
+ // let x : mat2x3<f32> = s.m;
+ // }
+ ProgramBuilder b;
+ auto* S = b.Structure(
+ "S", Vector{
+ b.Member("m", b.ty.mat2x3<f32>(),
+ Vector{
+ b.MemberOffset(0_u),
+ b.Stride(32_u),
+ b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+ b.RowMajor(),
+ }),
+ });
+ b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+ b.Func("f", tint::Empty, b.ty.void_(),
+ Vector{
+ b.Decl(b.Let("x", b.ty.mat2x3<f32>(), b.MemberAccessor("s", "m"))),
+ },
+ Vector{
+ b.Stage(ast::PipelineStage::kCompute),
+ b.WorkgroupSize(1_i),
+ });
+
+ auto* expect = R"(
+struct S {
+ /* @offset(0u) */
+ @stride(32) @internal(disable_validation__ignore_stride)
+ m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+ let x : mat2x3<f32> = transpose(s.m);
+}
+)";
+
+ auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::spirv::reader