[spirv-reader] Add TransposeRowMajor transform

Replace structure members that have @row_major attributes with the
transposed equivalents, and then update all relevant accessor
expressions to either transpose the whole matrix or swap accessor
indices as necessary.

Bug: 364267168
Change-Id: I3f25276a6ec96b68bf0bab05ee79b482f3a978e1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/207418
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel b/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
index 7bb61d7..4c3560d 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
@@ -44,6 +44,7 @@
     "decompose_strided_matrix.cc",
     "fold_trivial_lets.cc",
     "pass_workgroup_id_as_argument.cc",
+    "transpose_row_major.cc",
   ],
   hdrs = [
     "atomics.h",
@@ -51,6 +52,7 @@
     "decompose_strided_matrix.h",
     "fold_trivial_lets.h",
     "pass_workgroup_id_as_argument.h",
+    "transpose_row_major.h",
   ],
   deps = [
     "//src/tint/api/common",
@@ -92,6 +94,7 @@
     "decompose_strided_matrix_test.cc",
     "fold_trivial_lets_test.cc",
     "pass_workgroup_id_as_argument_test.cc",
+    "transpose_row_major_test.cc",
   ],
   deps = [
     "//src/tint/api/common",
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake b/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
index 361645f..5a87287 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
@@ -51,6 +51,8 @@
   lang/spirv/reader/ast_lower/fold_trivial_lets.h
   lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
   lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h
+  lang/spirv/reader/ast_lower/transpose_row_major.cc
+  lang/spirv/reader/ast_lower/transpose_row_major.h
 )
 
 tint_target_add_dependencies(tint_lang_spirv_reader_ast_lower lib
@@ -98,6 +100,7 @@
   lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
   lang/spirv/reader/ast_lower/fold_trivial_lets_test.cc
   lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
+  lang/spirv/reader/ast_lower/transpose_row_major_test.cc
 )
 
 tint_target_add_dependencies(tint_lang_spirv_reader_ast_lower_test test
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.gn b/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
index 4acb4a2..6252459 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
@@ -55,6 +55,8 @@
       "fold_trivial_lets.h",
       "pass_workgroup_id_as_argument.cc",
       "pass_workgroup_id_as_argument.h",
+      "transpose_row_major.cc",
+      "transpose_row_major.h",
     ]
     deps = [
       "${dawn_root}/src/utils:utils",
@@ -96,6 +98,7 @@
         "decompose_strided_matrix_test.cc",
         "fold_trivial_lets_test.cc",
         "pass_workgroup_id_as_argument_test.cc",
+        "transpose_row_major_test.cc",
       ]
       deps = [
         "${dawn_root}/src/utils:utils",
diff --git a/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.cc b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.cc
new file mode 100644
index 0000000..f77c7c1
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.cc
@@ -0,0 +1,344 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+#include "src/tint/lang/wgsl/sem/call.h"
+#include "src/tint/lang/wgsl/sem/index_accessor_expression.h"
+#include "src/tint/lang/wgsl/sem/load.h"
+#include "src/tint/lang/wgsl/sem/member_accessor_expression.h"
+#include "src/tint/utils/rtti/switch.h"
+
+using namespace tint::core::fluent_types;  // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::TransposeRowMajor);
+
+namespace tint::spirv::reader {
+
+TransposeRowMajor::TransposeRowMajor() = default;
+
+TransposeRowMajor::~TransposeRowMajor() = default;
+
+/// PIMPL state for the transform.
+struct TransposeRowMajor::State {
+    /// The source program
+    const Program& src;
+    /// The target program builder
+    ProgramBuilder b;
+    /// The clone context
+    program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};
+    /// The semantic info.
+    const sem::Info& sem = src.Sem();
+
+    /// Map from matrix reference to column load helper function.
+    Hashmap<const core::type::Type*, Symbol, 4> column_load_helpers;
+
+    /// Map from matrix reference to column store helper function.
+    Hashmap<const core::type::Type*, Symbol, 4> column_store_helpers;
+
+    /// Constructor
+    /// @param program the source program
+    explicit State(const Program& program) : src(program) {}
+
+    ApplyResult Run() {
+        // Scan the program for all storage and uniform structure matrix members with the @row_major
+        // attribute. Replace these matrices with a transposed version, and populate the
+        // `transposed_members` map with the members that have been replaced.
+        Hashset<const core::type::StructMember*, 8> transposed_members;
+        for (auto* decl : src.AST().TypeDecls()) {
+            if (auto* str = decl->As<ast::Struct>()) {
+                auto* str_ty = src.Sem().Get(str);
+                if (!str_ty->UsedAs(core::AddressSpace::kUniform) &&
+                    !str_ty->UsedAs(core::AddressSpace::kStorage)) {
+                    continue;
+                }
+                for (auto* member : str_ty->Members()) {
+                    auto* matrix = member->Type()->As<core::type::Matrix>();
+                    if (!matrix) {
+                        continue;
+                    }
+                    auto* attr = ast::GetAttribute<ast::RowMajorAttribute>(
+                        member->Declaration()->attributes);
+                    if (!attr) {
+                        continue;
+                    }
+                    // We've got a struct member of a matrix type with a row-major memory layout.
+                    // Transpose it, remove the @row_major attribute, and record it in the set.
+                    auto transposed_matrix = b.ty.mat(CreateASTTypeFor(ctx, matrix->Type()),
+                                                      matrix->Rows(), matrix->Columns());
+                    ctx.Remove(member->Declaration()->attributes, attr);
+                    auto* replacement = b.Member(ctx.Clone(member->Name()), transposed_matrix,
+                                                 ctx.Clone(member->Declaration()->attributes));
+                    ctx.Replace(member->Declaration(), replacement);
+                    transposed_members.Add(member);
+                }
+            }
+        }
+
+        if (transposed_members.IsEmpty()) {
+            return SkipTransform;
+        }
+
+        // Look for expressions that access the matrix.
+        // The `row_major_accesses` map tracks expressions that are accessing a transposed matrix
+        // (or a subset of it), and the originating struct member extraction they came from.
+        Hashmap<const sem::ValueExpression*, const sem::StructMemberAccess*, 8> row_major_accesses;
+        for (auto* node : src.ASTNodes().Objects()) {
+            // Check for assignments to all or part of a transposed matrix and replace them.
+            if (auto* assign = node->As<ast::AssignmentStatement>()) {
+                auto* lhs = src.Sem().GetVal(assign->lhs);
+                auto row_major_access = row_major_accesses.Get(lhs);
+                if (row_major_access) {
+                    ReplaceAssignment(assign);
+                    row_major_accesses.Remove(lhs);
+                }
+            }
+
+            auto* sem_expr = sem.GetVal(node);
+            if (!sem_expr) {
+                continue;
+            }
+
+            if (auto* accessor = sem_expr->UnwrapLoad()->As<sem::AccessorExpression>()) {
+                if (auto* member_access = accessor->As<sem::StructMemberAccess>()) {
+                    // Check if we are accessing a struct member that is a transposed matrix.
+                    if (transposed_members.Contains(member_access->Member())) {
+                        if (member_access->Type()->Is<core::type::MemoryView>()) {
+                            // This is a pointer, so track the access until we hit a load or store.
+                            row_major_accesses.Add(member_access, member_access);
+                        } else {
+                            // This is not a pointer, so we are extracting a matrix from a value
+                            // type. Transpose the matrix now so that all child expressions behave
+                            // as expected.
+                            ctx.Replace(
+                                member_access->Declaration(),
+                                b.Call("transpose", ctx.Clone(member_access->Declaration())));
+                        }
+                    }
+                } else {
+                    // For non-struct-member accesses, check if the base object is a transposed
+                    // matrix and track the resulting sub-expression if so.
+                    auto row_major_access = row_major_accesses.Get(accessor->Object());
+                    if (row_major_access) {
+                        row_major_accesses.Add(accessor, *row_major_access.value);
+                        row_major_accesses.Remove(accessor->Object());
+                    }
+                }
+            }
+
+            // Check for loads from all or part of a transposed matrix and replace them.
+            if (auto* load = sem_expr->As<sem::Load>()) {
+                auto row_major_access = row_major_accesses.Get(load->Source());
+                if (row_major_access) {
+                    ReplaceLoad(load);
+                    row_major_accesses.Remove(load->Source());
+                }
+            }
+
+            // Check for constructors of structures that contain transposed matrices, and transpose
+            // the relevant arguments.
+            if (auto* call = sem_expr->As<sem::Call>()) {
+                if (auto* str = call->Type()->As<core::type::Struct>()) {
+                    for (uint32_t i = 0; i < str->Members().Length(); i++) {
+                        if (transposed_members.Contains(str->Members()[i])) {
+                            auto* arg = call->Arguments()[i]->Declaration();
+                            ctx.Replace(arg, b.Call("transpose", ctx.Clone(arg)));
+                        }
+                    }
+                }
+            }
+        }
+
+        ctx.Clone();
+        return resolver::Resolve(b);
+    }
+
+    /// Replace an assignment to a transposed matrix.
+    /// @param assign the assignment statement to replace
+    void ReplaceAssignment(const ast::AssignmentStatement* assign) {
+        auto* lhs = src.Sem().GetVal(assign->lhs);
+        Switch(
+            src.Sem().GetVal(assign->rhs)->Type(),
+            [&](const core::type::Matrix*) {
+                // We are storing the whole matrix, so just transpose the RHS.
+                ctx.Replace(assign->rhs, b.Call("transpose", ctx.Clone(assign->rhs)));
+            },
+            [&](const core::type::Vector*) {
+                // We are storing a single column, which has to be done element-wise.
+                // Call a helper function to do this.
+                auto* col_access = lhs->As<sem::IndexAccessorExpression>();
+                TINT_ASSERT(col_access);
+                auto* to = b.AddressOf(ctx.Clone(col_access->Object()->Declaration()));
+                auto* idx = b.Call("u32", ctx.Clone(col_access->Index()->Declaration()));
+                auto* col = ctx.Clone(assign->rhs);
+                ctx.Replace(assign,
+                            b.CallStmt(b.Call(StoreColumnHelper(col_access->Object()->Type()), to,
+                                              idx, col)));
+            },
+            [&](const core::type::Scalar*) {
+                // We are storing a single element, so reconstruct the index accessors with the
+                // column and row indices swapped over.
+                ctx.Replace(assign->lhs, TransposeAccessIndices(lhs));
+            },
+            TINT_ICE_ON_NO_MATCH);
+    }
+
+    /// Get (or create) a helper function that will assign a column to a transposed matrix.
+    /// @param dest_type the matrix type we are storing to
+    /// @returns the name of the helper function
+    Symbol StoreColumnHelper(const core::type::Type* dest_type) {
+        auto* ref_type = dest_type->As<core::type::Reference>();
+        TINT_ASSERT(ref_type);
+        auto* matrix_type = ref_type->StoreType()->As<core::type::Matrix>();
+        TINT_ASSERT(matrix_type);
+        return column_store_helpers.GetOrAdd(ref_type, [&] {
+            // The helper function will look like this:
+            //   fn tint_store_row_major(to: ptr<private, mat3x2<f32>>, idx : u32, col: vec3<f32>) {
+            //     to[0][idx] = col[0];
+            //     to[1][idx] = col[1];
+            //     to[2][idx] = col[2];
+            //   }
+            auto name = b.Symbols().New("tint_store_row_major_column");
+            auto transposed = b.ty.mat(CreateASTTypeFor(ctx, matrix_type->Type()),
+                                       matrix_type->Rows(), matrix_type->Columns());
+            auto ptr = b.ty.ptr(ref_type->AddressSpace(), transposed,
+                                ref_type->AddressSpace() == core::AddressSpace::kStorage
+                                    ? ref_type->Access()
+                                    : core::Access::kUndefined);
+            auto* to = b.Param("tint_to", ptr);
+            auto* idx = b.Param("tint_idx", b.ty.u32());
+            auto* col = b.Param("tint_col", CreateASTTypeFor(ctx, matrix_type->ColumnType()));
+            Vector<const ast::Statement*, 4> body;
+            for (uint32_t i = 0; i < matrix_type->Rows(); i++) {
+                body.Push(b.Assign(b.IndexAccessor(b.IndexAccessor(to, b.Expr(AInt(i))), idx),
+                                   b.IndexAccessor(col, b.Expr(AInt(i)))));
+            }
+            b.Func(name, Vector{to, idx, col}, {}, std::move(body));
+            return name;
+        });
+    }
+
+    /// Replace a load from a transposed matrix.
+    /// @param load the load expression to replace
+    void ReplaceLoad(const sem::Load* load) {
+        Switch(
+            load->Type(),
+            [&](const core::type::Matrix*) {
+                // We are loading the whole matrix, so just transpose the result.
+                ctx.Replace(load->Declaration(),
+                            b.Call("transpose", ctx.Clone(load->Declaration())));
+            },
+            [&](const core::type::Vector*) {
+                // We are loading a single column, which has to be done element-wise.
+                // Call a helper function to do this.
+                auto* col_access = load->Source()->As<sem::IndexAccessorExpression>();
+                TINT_ASSERT(col_access);
+                auto* from = b.AddressOf(ctx.Clone(col_access->Object()->Declaration()));
+                auto* idx = b.Call("u32", ctx.Clone(col_access->Index()->Declaration()));
+                ctx.Replace(load->Declaration(),
+                            b.Call(LoadColumnHelper(col_access->Object()->Type()), from, idx));
+            },
+            [&](const core::type::Scalar*) {
+                // We are loading a single element, so reconstruct the index accessors with the
+                // column and row indices swapped over.
+                ctx.Replace(load->Declaration(), TransposeAccessIndices(load->Source()));
+            },
+            TINT_ICE_ON_NO_MATCH);
+    }
+
+    /// Get (or create) a helper function that will load a column from a transposed matrix.
+    /// @param src_type the matrix type we are load from
+    /// @returns the name of the helper function
+    Symbol LoadColumnHelper(const core::type::Type* src_type) {
+        auto* ref_type = src_type->As<core::type::Reference>();
+        TINT_ASSERT(ref_type);
+        auto* matrix_type = ref_type->StoreType()->As<core::type::Matrix>();
+        TINT_ASSERT(matrix_type);
+        return column_load_helpers.GetOrAdd(ref_type, [&] {
+            // The helper function will look like this:
+            //   fn tint_load_row_major(from: ptr<private, mat3x2<f32>>, idx : u32) -> vec3<f32> {
+            //     return vec3<f32>(to[0][idx], to[1][idx], to[2][idx]);
+            //   }
+            auto name = b.Symbols().New("tint_load_row_major_column");
+            auto transposed = b.ty.mat(CreateASTTypeFor(ctx, matrix_type->Type()),
+                                       matrix_type->Rows(), matrix_type->Columns());
+            auto ptr = b.ty.ptr(ref_type->AddressSpace(), transposed,
+                                ref_type->AddressSpace() == core::AddressSpace::kStorage
+                                    ? ref_type->Access()
+                                    : core::Access::kUndefined);
+            auto* from = b.Param("tint_from", ptr);
+            auto* idx = b.Param("tint_idx", b.ty.u32());
+            Vector<const ast::Expression*, 4> rows;
+            for (uint32_t i = 0; i < matrix_type->Rows(); i++) {
+                rows.Push(b.IndexAccessor(b.IndexAccessor(from, b.Expr(AInt(i))), idx));
+            }
+            b.Func(name, Vector{from, idx}, CreateASTTypeFor(ctx, matrix_type->ColumnType()),
+                   Vector{
+                       b.Return(b.Call(CreateASTTypeFor(ctx, matrix_type->ColumnType()),
+                                       std::move(rows))),
+                   });
+            return name;
+        });
+    }
+
+    /// Swap the column and row indices for a matrix element accessor chain.
+    /// @param expr the accessor expression to transpose
+    /// @returns the transposed access chain
+    const ast::Expression* TransposeAccessIndices(const sem::ValueExpression* expr) {
+        auto* row_access = expr->As<sem::AccessorExpression>();
+        TINT_ASSERT(row_access);
+        auto* col_access = row_access->Object()->As<sem::IndexAccessorExpression>();
+        TINT_ASSERT(col_access);
+        auto* matrix = ctx.Clone(col_access->Object()->Declaration());
+        auto* col_idx = ctx.Clone(col_access->Index()->Declaration());
+
+        // The row index could either be a array accessor or a vector component swizzle.
+        const ast::Expression* row_idx = nullptr;
+        if (auto* index = row_access->As<sem::IndexAccessorExpression>()) {
+            row_idx = ctx.Clone(index->Index()->Declaration());
+        } else if (auto* swizzle = row_access->As<sem::Swizzle>()) {
+            row_idx = b.Expr(u32(swizzle->Indices()[0]));
+        } else {
+            TINT_UNREACHABLE();
+        }
+
+        return b.IndexAccessor(b.IndexAccessor(matrix, row_idx), col_idx);
+    }
+};
+
+ast::transform::Transform::ApplyResult TransposeRowMajor::Apply(const Program& src,
+                                                                const ast::transform::DataMap&,
+                                                                ast::transform::DataMap&) const {
+    return State(src).Run();
+}
+
+}  // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h
new file mode 100644
index 0000000..95a1bfd
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h
@@ -0,0 +1,61 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_SPIRV_READER_AST_LOWER_TRANSPOSE_ROW_MAJOR_H_
+#define SRC_TINT_LANG_SPIRV_READER_AST_LOWER_TRANSPOSE_ROW_MAJOR_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::spirv::reader {
+
+/// TransposeRowMajor replaces matrix members of storage or uniform buffer structures that have a
+/// @row_major attribute,
+///
+/// This transform is used by the SPIR-V reader to handle the SPIR-V RowMajor attribute.
+///
+/// @note Depends on the following transforms to have been run first:
+/// * SimplifyPointers
+class TransposeRowMajor final : public Castable<TransposeRowMajor, ast::transform::Transform> {
+  public:
+    /// Constructor
+    TransposeRowMajor();
+
+    /// Destructor
+    ~TransposeRowMajor() override;
+
+    /// @copydoc ast::transform::Transform::Apply
+    ApplyResult Apply(const Program& program,
+                      const ast::transform::DataMap& inputs,
+                      ast::transform::DataMap& outputs) const override;
+
+  private:
+    struct State;
+};
+
+}  // namespace tint::spirv::reader
+
+#endif  // SRC_TINT_LANG_SPIRV_READER_AST_LOWER_TRANSPOSE_ROW_MAJOR_H_
diff --git a/src/tint/lang/spirv/reader/ast_lower/transpose_row_major_test.cc b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major_test.cc
new file mode 100644
index 0000000..e1e26a9
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/transpose_row_major_test.cc
@@ -0,0 +1,985 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/reader/ast_lower/transpose_row_major.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/ast/transform/helper_test.h"
+#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+
+namespace tint::spirv::reader {
+namespace {
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+using TransposeRowMajorTest = ast::transform::TransformTest;
+using SimplifyPointers = ast::transform::SimplifyPointers;
+
+TEST_F(TransposeRowMajorTest, ShouldRunEmptyModule) {
+    auto* src = R"()";
+
+    EXPECT_FALSE(ShouldRun<TransposeRowMajor>(src));
+}
+
+TEST_F(TransposeRowMajorTest, ShouldRunColumnMajorMatrix) {
+    auto* src = R"(
+struct S {
+  m : mat3x2<f32>
+}
+
+@group(0) @binding(0)
+var<uniform> s : S;
+)";
+
+    EXPECT_FALSE(ShouldRun<TransposeRowMajor>(src));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformMatrix) {
+    // struct S {
+    //   @offset(16)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : mat2x3<f32> = s.m;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(16_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("x", b.ty.mat2x3<f32>(), b.MemberAccessor("s", "m"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : mat2x3<f32> = transpose(s.m);
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformColumn) {
+    // struct S {
+    //   @offset(16)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : vec3<f32> = s.m[1];
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(16_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Decl(b.Let("x", b.ty.vec3<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+fn tint_load_row_major_column(tint_from : ptr<uniform, mat3x2<f32>>, tint_idx : u32) -> vec3<f32> {
+  return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : vec3<f32> = tint_load_row_major_column(&(s.m), u32(1i));
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformElement_MemberAccessor) {
+    // struct S {
+    //   @offset(16)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   let x : f32 = s.m[col_idx].z;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(16_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+               b.Decl(b.Let(
+                   "x", b.ty.f32(),
+                   b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"), "z"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  let x : f32 = s.m[2u][col_idx];
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformElement_IndexAccessor) {
+    // struct S {
+    //   @offset(16)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   let row_idx : i32 = 2i;
+    //   let x : f32 = s.m[col_idx][row_idx];
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(16_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+               b.Decl(b.Let("row_idx", b.ty.i32(), b.Expr(2_i))),
+               b.Decl(b.Let("x", b.ty.f32(),
+                            b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"),
+                                            "row_idx"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  let row_idx : i32 = 2i;
+  let x : f32 = s.m[row_idx][col_idx];
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ReadUniformSwizzle) {
+    // struct S {
+    //   @offset(16)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   let x : vec2<f32> = s.m[1].zx;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(16_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+               b.Decl(b.Let(
+                   "x", b.ty.vec2<f32>(),
+                   b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"), "zx"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+fn tint_load_row_major_column(tint_from : ptr<uniform, mat3x2<f32>>, tint_idx : u32) -> vec3<f32> {
+  return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  let x : vec2<f32> = tint_load_row_major_column(&(s.m), u32(col_idx)).zx;
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageMatrix) {
+    // struct S {
+    //   @offset(8)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   s.m = mat2x3<f32>(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(8_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Assign(b.MemberAccessor("s", "m"), b.Call<mat2x3<f32>>(1_f, 2_f, 3_f, 4_f, 5_f, 6_f)),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+struct S {
+  @size(8)
+  padding_0 : u32,
+  /* @offset(8u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  s.m = transpose(mat2x3<f32>(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f));
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageColumn) {
+    // struct S {
+    //   @offset(8)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   s.m[1] = vec3<f32>(1.0, 2.0, 3.0);
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(8_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"),
+                        b.Call<vec3<f32>>(1_f, 2_f, 3_f)),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+fn tint_store_row_major_column(tint_to : ptr<storage, mat3x2<f32>, read_write>, tint_idx : u32, tint_col : vec3<f32>) {
+  tint_to[0][tint_idx] = tint_col[0];
+  tint_to[1][tint_idx] = tint_col[1];
+  tint_to[2][tint_idx] = tint_col[2];
+}
+
+struct S {
+  @size(8)
+  padding_0 : u32,
+  /* @offset(8u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  tint_store_row_major_column(&(s.m), u32(col_idx), vec3<f32>(1.0f, 2.0f, 3.0f));
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageElement_MemberAccessor) {
+    // struct S {
+    //   @offset(8)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   s.m[1].z = 1.0;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(8_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+            b.Assign(b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"), "z"),
+                     1_f),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+struct S {
+  @size(8)
+  padding_0 : u32,
+  /* @offset(8u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  s.m[2u][col_idx] = 1.0f;
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, WriteStorageElement_IndexAccessor) {
+    // struct S {
+    //   @offset(8)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   let row_idx : i32 = 2i;
+    //   s.m[1][idx] = 1.0;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(8_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+               b.Decl(b.Let("row_idx", b.ty.i32(), b.Expr(2_i))),
+               b.Assign(b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "m"), "col_idx"),
+                                        "row_idx"),
+                        1_f),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  @size(8)
+  padding_0 : u32,
+  /* @offset(8u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  let row_idx : i32 = 2i;
+  s.m[row_idx][col_idx] = 1.0f;
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, ExtractFromLoadedStruct) {
+    // struct S {
+    //   @offset(16)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let col_idx : i32 = 1i;
+    //   let row_idx : i32 = 2i;
+    //   let load = s;
+    //   let m : mat2x3<f32> = load.m;
+    //   let c : vec3<f32> = load.m[col_idx];
+    //   let e : vec3<f32> = load.m[col_idx][row_idx];
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(16_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func(
+        "f", tint::Empty, b.ty.void_(),
+        Vector{
+            b.Decl(b.Let("col_idx", b.ty.i32(), b.Expr(1_i))),
+            b.Decl(b.Let("row_idx", b.ty.i32(), b.Expr(2_i))),
+            b.Decl(b.Let("load", b.ty.Of(S), b.Expr("s"))),
+            b.Decl(b.Let("m", b.ty.mat2x3<f32>(), b.MemberAccessor("load", "m"))),
+            b.Decl(b.Let("c", b.ty.vec3<f32>(),
+                         b.IndexAccessor(b.MemberAccessor("load", "m"), "col_idx"))),
+            b.Decl(b.Let("e", b.ty.f32(),
+                         b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("load", "m"), "col_idx"),
+                                         "row_idx"))),
+        },
+        Vector{
+            b.Stage(ast::PipelineStage::kCompute),
+            b.WorkgroupSize(1_i),
+        });
+
+    auto* expect = R"(
+struct S {
+  @size(16)
+  padding_0 : u32,
+  /* @offset(16u) */
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let col_idx : i32 = 1i;
+  let row_idx : i32 = 2i;
+  let load : S = s;
+  let m : mat2x3<f32> = transpose(load.m);
+  let c : vec3<f32> = transpose(load.m)[col_idx];
+  let e : f32 = transpose(load.m)[col_idx][row_idx];
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, InsertInStructConstructor) {
+    // struct S {
+    //   @offset(0) @row_major m1 : mat2x3<f32>,
+    //   @offset(32) m2 : mat4x2<f32>,
+    //   @offset(64) @row_major m3 : mat4x2<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let m1 = mat2x3<f32>();
+    //   let m2 = mat4x2<f32>();
+    //   s = S(m, m2, m2);
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(0_u),
+                                                b.RowMajor(),
+                                            }),
+                                   b.Member("m1", b.ty.mat4x2<f32>(),
+                                            Vector{
+                                                b.MemberOffset(32_u),
+                                            }),
+                                   b.Member("m2", b.ty.mat4x2<f32>(),
+                                            Vector{
+                                                b.MemberOffset(64_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("m1", b.Call("mat2x3f"))),
+               b.Decl(b.Let("m2", b.Call("mat4x2f"))),
+               b.Assign("s", b.Call("S", "m1", "m2", "m2")),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  /* @offset(0u) */
+  m : mat3x2<f32>,
+  @size(8)
+  padding_0 : u32,
+  /* @offset(32u) */
+  m1 : mat4x2<f32>,
+  /* @offset(64u) */
+  m2 : mat2x4<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let m1 = mat2x3f();
+  let m2 = mat4x2f();
+  s = S(transpose(m1), m2, transpose(m2));
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, DeeplyNested) {
+    // struct Inner {
+    //   @offset(0)
+    //   @row_major
+    //   m : mat4x3<f32>,
+    // };
+    // struct Outer {
+    //   @offset(0)
+    //   arr : array<Inner, 4>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> buffer : Outer;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let m = buffer.arr[1].m;
+    //   buffer.arr[0].m[3] = m[2];
+    // }
+    ProgramBuilder b;
+    auto* inner = b.Structure("Inner", Vector{
+                                           b.Member("m", b.ty.mat4x3<f32>(),
+                                                    Vector{
+                                                        b.MemberOffset(0_u),
+                                                        b.RowMajor(),
+                                                    }),
+                                       });
+    auto* outer = b.Structure("Outer", Vector{
+                                           b.Member("arr", b.ty.array(b.ty.Of(inner), 4_a),
+                                                    Vector{
+                                                        b.MemberOffset(0_u),
+                                                    }),
+                                       });
+    b.GlobalVar("buffer", b.ty.Of(outer), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let(
+                   "m", b.ty.mat4x3<f32>(),
+                   b.MemberAccessor(b.IndexAccessor(b.MemberAccessor("buffer", "arr"), 1_a), "m"))),
+               b.Assign(b.IndexAccessor(
+                            b.MemberAccessor(
+                                b.IndexAccessor(b.MemberAccessor("buffer", "arr"), 0_a), "m"),
+                            3_a),
+                        b.IndexAccessor("m", 2_a)),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+fn tint_store_row_major_column(tint_to : ptr<storage, mat3x4<f32>, read_write>, tint_idx : u32, tint_col : vec3<f32>) {
+  tint_to[0][tint_idx] = tint_col[0];
+  tint_to[1][tint_idx] = tint_col[1];
+  tint_to[2][tint_idx] = tint_col[2];
+}
+
+struct Inner {
+  /* @offset(0u) */
+  m : mat3x4<f32>,
+}
+
+struct Outer {
+  /* @offset(0u) */
+  arr : array<Inner, 4>,
+}
+
+@group(0) @binding(0) var<storage, read_write> buffer : Outer;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let m : mat4x3<f32> = transpose(buffer.arr[1].m);
+  tint_store_row_major_column(&(buffer.arr[0].m), u32(3), m[2]);
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, MultipleColumnHelpers) {
+    // struct S {
+    //   @offset(0) @row_major m1 : mat2x3<f32>,
+    //   @offset(32) @row_major m2 : mat4x2<f32>,
+    // };
+    // @group(0) @binding(0) var<storage, read_write> s : S;
+    // var<private> ps : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   ps.m1[0] = s.m1[1];
+    //   ps.m1[1] = s.m1[0];
+    //   ps.m2[2] = s.m1[3];
+    //   ps.m2[3] = s.m1[2];
+    //
+    //   s.m1[0] = ps.m1[0];
+    //   s.m1[1] = ps.m1[1];
+    //   s.m2[2] = ps.m2[2];
+    //   s.m2[3] = ps.m2[3];
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure("S", Vector{
+                                   b.Member("m1", b.ty.mat2x3<f32>(),
+                                            Vector{
+                                                b.MemberOffset(0_u),
+                                                b.RowMajor(),
+                                            }),
+                                   b.Member("m2", b.ty.mat4x2<f32>(),
+                                            Vector{
+                                                b.MemberOffset(32_u),
+                                                b.RowMajor(),
+                                            }),
+                               });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kStorage, core::Access::kReadWrite,
+                b.Group(0_a), b.Binding(0_a));
+    b.GlobalVar("ps", b.ty.Of(S), core::AddressSpace::kPrivate);
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m1"), 0_u),
+                        b.IndexAccessor(b.MemberAccessor("s", "m1"), 1_u)),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m1"), 1_u),
+                        b.IndexAccessor(b.MemberAccessor("s", "m1"), 0_u)),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m2"), 2_u),
+                        b.IndexAccessor(b.MemberAccessor("s", "m2"), 3_u)),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("ps", "m2"), 3_u),
+                        b.IndexAccessor(b.MemberAccessor("s", "m2"), 2_u)),
+
+               b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m1"), 0_u),
+                        b.IndexAccessor(b.MemberAccessor("ps", "m1"), 0_u)),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m1"), 1_u),
+                        b.IndexAccessor(b.MemberAccessor("ps", "m1"), 1_u)),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m2"), 2_u),
+                        b.IndexAccessor(b.MemberAccessor("ps", "m2"), 2_u)),
+               b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m2"), 3_u),
+                        b.IndexAccessor(b.MemberAccessor("ps", "m2"), 3_u)),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+fn tint_load_row_major_column(tint_from : ptr<storage, mat3x2<f32>, read_write>, tint_idx : u32) -> vec3<f32> {
+  return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+fn tint_store_row_major_column(tint_to : ptr<private, mat3x2<f32>>, tint_idx : u32, tint_col : vec3<f32>) {
+  tint_to[0][tint_idx] = tint_col[0];
+  tint_to[1][tint_idx] = tint_col[1];
+  tint_to[2][tint_idx] = tint_col[2];
+}
+
+fn tint_load_row_major_column_1(tint_from : ptr<storage, mat2x4<f32>, read_write>, tint_idx : u32) -> vec2<f32> {
+  return vec2<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx]);
+}
+
+fn tint_store_row_major_column_1(tint_to : ptr<private, mat2x4<f32>>, tint_idx : u32, tint_col : vec2<f32>) {
+  tint_to[0][tint_idx] = tint_col[0];
+  tint_to[1][tint_idx] = tint_col[1];
+}
+
+fn tint_load_row_major_column_2(tint_from : ptr<private, mat3x2<f32>>, tint_idx : u32) -> vec3<f32> {
+  return vec3<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx], tint_from[2][tint_idx]);
+}
+
+fn tint_store_row_major_column_2(tint_to : ptr<storage, mat3x2<f32>, read_write>, tint_idx : u32, tint_col : vec3<f32>) {
+  tint_to[0][tint_idx] = tint_col[0];
+  tint_to[1][tint_idx] = tint_col[1];
+  tint_to[2][tint_idx] = tint_col[2];
+}
+
+fn tint_load_row_major_column_3(tint_from : ptr<private, mat2x4<f32>>, tint_idx : u32) -> vec2<f32> {
+  return vec2<f32>(tint_from[0][tint_idx], tint_from[1][tint_idx]);
+}
+
+fn tint_store_row_major_column_3(tint_to : ptr<storage, mat2x4<f32>, read_write>, tint_idx : u32, tint_col : vec2<f32>) {
+  tint_to[0][tint_idx] = tint_col[0];
+  tint_to[1][tint_idx] = tint_col[1];
+}
+
+struct S {
+  /* @offset(0u) */
+  m1 : mat3x2<f32>,
+  /* @offset(32u) */
+  m2 : mat2x4<f32>,
+}
+
+@group(0) @binding(0) var<storage, read_write> s : S;
+
+var<private> ps : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  tint_store_row_major_column(&(ps.m1), u32(0u), tint_load_row_major_column(&(s.m1), u32(1u)));
+  tint_store_row_major_column(&(ps.m1), u32(1u), tint_load_row_major_column(&(s.m1), u32(0u)));
+  tint_store_row_major_column_1(&(ps.m2), u32(2u), tint_load_row_major_column_1(&(s.m2), u32(3u)));
+  tint_store_row_major_column_1(&(ps.m2), u32(3u), tint_load_row_major_column_1(&(s.m2), u32(2u)));
+  tint_store_row_major_column_2(&(s.m1), u32(0u), tint_load_row_major_column_2(&(ps.m1), u32(0u)));
+  tint_store_row_major_column_2(&(s.m1), u32(1u), tint_load_row_major_column_2(&(ps.m1), u32(1u)));
+  tint_store_row_major_column_3(&(s.m2), u32(2u), tint_load_row_major_column_3(&(ps.m2), u32(2u)));
+  tint_store_row_major_column_3(&(s.m2), u32(3u), tint_load_row_major_column_3(&(ps.m2), u32(3u)));
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(TransposeRowMajorTest, PreserveMatrixStride) {
+    // struct S {
+    //   @offset(0)
+    //   @stride(32)
+    //   @row_major
+    //   m : mat2x3<f32>,
+    // };
+    // @group(0) @binding(0) var<uniform> s : S;
+    //
+    // @compute @workgroup_size(1)
+    // fn f() {
+    //   let x : mat2x3<f32> = s.m;
+    // }
+    ProgramBuilder b;
+    auto* S = b.Structure(
+        "S", Vector{
+                 b.Member("m", b.ty.mat2x3<f32>(),
+                          Vector{
+                              b.MemberOffset(0_u),
+                              b.Stride(32_u),
+                              b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
+                              b.RowMajor(),
+                          }),
+             });
+    b.GlobalVar("s", b.ty.Of(S), core::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
+    b.Func("f", tint::Empty, b.ty.void_(),
+           Vector{
+               b.Decl(b.Let("x", b.ty.mat2x3<f32>(), b.MemberAccessor("s", "m"))),
+           },
+           Vector{
+               b.Stage(ast::PipelineStage::kCompute),
+               b.WorkgroupSize(1_i),
+           });
+
+    auto* expect = R"(
+struct S {
+  /* @offset(0u) */
+  @stride(32) @internal(disable_validation__ignore_stride)
+  m : mat3x2<f32>,
+}
+
+@group(0) @binding(0) var<uniform> s : S;
+
+@compute @workgroup_size(1i)
+fn f() {
+  let x : mat2x3<f32> = transpose(s.m);
+}
+)";
+
+    auto got = Run<SimplifyPointers, TransposeRowMajor>(resolver::Resolve(b));
+
+    EXPECT_EQ(expect, str(got));
+}
+
+}  // namespace
+}  // namespace tint::spirv::reader