[tint][ir] Fix Std140 transform for arrays of matrices

Decompose matrices in arrays by creating a new struct. This matches the
logic of the old AST transform.

Bug: 338727551
Change-Id: I518dcf44d61ba15c13e599871bf57ec2df7c1794
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/187401
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/transform/std140.cc b/src/tint/lang/core/ir/transform/std140.cc
index 71b02ce..a61b8bc 100644
--- a/src/tint/lang/core/ir/transform/std140.cc
+++ b/src/tint/lang/core/ir/transform/std140.cc
@@ -27,14 +27,23 @@
 
 #include "src/tint/lang/core/ir/transform/std140.h"
 
+#include <cstdint>
 #include <utility>
 
+#include "src/tint/lang/core/address_space.h"
 #include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/function_param.h"
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/matrix.h"
+#include "src/tint/lang/core/type/memory_view.h"
+#include "src/tint/lang/core/type/pointer.h"
 #include "src/tint/lang/core/type/struct.h"
+#include "src/tint/lang/core/type/type.h"
+#include "src/tint/utils/containers/hashmap.h"
+#include "src/tint/utils/containers/vector.h"
+#include "src/tint/utils/text/string_stream.h"
 
 using namespace tint::core::fluent_types;     // NOLINT
 using namespace tint::core::number_suffixes;  // NOLINT
@@ -73,43 +82,51 @@
         }
 
         // Find uniform buffers that contain matrices that need to be decomposed.
-        Vector<Var*, 8> buffer_variables;
+        Vector<std::pair<Var*, const core::type::Type*>, 8> buffer_variables;
         for (auto inst : *ir.root_block) {
-            auto* var = inst->As<Var>();
-            if (!var || !var->Alive()) {
-                continue;
-            }
-            auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
-            if (!ptr || ptr->AddressSpace() != core::AddressSpace::kUniform) {
-                continue;
-            }
-            if (RewriteType(ptr->StoreType()) != ptr->StoreType()) {
-                buffer_variables.Push(var);
+            if (auto* var = inst->As<Var>()) {
+                auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
+                if (!ptr || ptr->AddressSpace() != core::AddressSpace::kUniform) {
+                    continue;
+                }
+                auto* store_type = RewriteType(ptr->StoreType());
+                if (store_type != ptr->StoreType()) {
+                    buffer_variables.Push(std::make_pair(var, store_type));
+                }
             }
         }
 
         // Now process the buffer variables, replacing them with new variables that have decomposed
         // matrices and updating all usages of the variables.
-        for (auto* var : buffer_variables) {
+        for (auto var_and_ty : buffer_variables) {
             // Create a new variable with the modified store type.
-            const auto& bp = var->BindingPoint();
-            auto* store_type = var->Result(0)->Type()->As<core::type::Pointer>()->StoreType();
-            auto* new_var = b.Var(ty.ptr(uniform, RewriteType(store_type)));
+            auto* old_var = var_and_ty.first;
+            auto* new_var = b.Var(ty.ptr(uniform, var_and_ty.second));
+            const auto& bp = old_var->BindingPoint();
             new_var->SetBindingPoint(bp->group, bp->binding);
-            if (auto name = ir.NameOf(var)) {
+            if (auto name = ir.NameOf(old_var)) {
                 ir.SetName(new_var->Result(0), name);
             }
 
-            // Replace every instruction that uses the original variable.
-            var->Result(0)->ForEachUse(
+            // Transform instructions that accessed the variable to use the decomposed var.
+            old_var->Result(0)->ForEachUse(
                 [&](Usage use) { Replace(use.instruction, new_var->Result(0)); });
 
             // Replace the original variable with the new variable.
-            var->ReplaceWith(new_var);
-            var->Destroy();
+            old_var->ReplaceWith(new_var);
+            old_var->Destroy();
         }
     }
 
+    /// @param type the type to check
+    /// @returns the matrix if @p type is a matrix that needs to be decomposed
+    static const core::type::Matrix* NeedsDecomposing(const core::type::Type* type) {
+        if (auto* mat = type->As<core::type::Matrix>(); mat && NeedsDecomposing(mat)) {
+            return mat;
+        }
+        return nullptr;
+    }
+
     /// @param mat the matrix type to check
     /// @returns true if @p mat needs to be decomposed
     static bool NeedsDecomposing(const core::type::Matrix* mat) {
@@ -125,10 +142,10 @@
     /// @param type the type to rewrite
     /// @returns the new type
     const core::type::Type* RewriteType(const core::type::Type* type) {
-        return rewritten_types.GetOrAdd(type, [&]() -> const core::type::Type* {
+        return rewritten_types.GetOrAdd(type, [&] {
             return tint::Switch(
                 type,
-                [&](const core::type::Array* arr) -> const core::type::Type* {
+                [&](const core::type::Array* arr) {
                     // Create a new array with element type potentially rewritten.
                     return ty.array(RewriteType(arr->ElemType()), arr->ConstantCount().value());
                 },
@@ -137,8 +154,7 @@
                     uint32_t member_index = 0;
                     Vector<const core::type::StructMember*, 4> new_members;
                     for (auto* member : str->Members()) {
-                        auto* mat = member->Type()->As<core::type::Matrix>();
-                        if (mat && NeedsDecomposing(mat)) {
+                        if (auto* mat = NeedsDecomposing(member->Type())) {
                             // Decompose these matrices into a separate member for each column.
                             member_index_map.Add(member, member_index);
                             auto* col = mat->ColumnType();
@@ -182,6 +198,32 @@
                     }
                     return new_str;
                 },
+                [&](const core::type::Matrix* mat) -> const core::type::Type* {
+                    if (!NeedsDecomposing(mat)) {
+                        return mat;
+                    }
+                    StringStream name;
+                    name << "mat" << mat->columns() << "x" << mat->rows() << "_"
+                         << mat->ColumnType()->type()->FriendlyName() << "_std140";
+                    Vector<core::type::StructMember*, 4> members;
+                    // Decompose these matrices into a separate member for each column.
+                    auto* col = mat->ColumnType();
+                    uint32_t offset = 0;
+                    for (uint32_t i = 0; i < mat->columns(); i++) {
+                        StringStream ss;
+                        ss << "col" << std::to_string(i);
+                        members.Push(ty.Get<core::type::StructMember>(
+                            sym.New(ss.str()), col, i, offset, col->Align(), col->Size(),
+                            core::type::StructMemberAttributes{}));
+                        offset += col->Align();
+                    }
+
+                    // Create a new struct with the rewritten members.
+                    return ty.Get<core::type::Struct>(
+                        sym.New(name.str()), std::move(members), col->Align(),
+                        col->Align() * mat->columns(),
+                        (col->Align() * (mat->columns() - 1)) + col->Size());
+                },
                 [&](Default) {
                     // This type cannot contain a matrix, so no changes needed.
                     return type;
@@ -189,19 +231,26 @@
         });
     }
 
-    /// Load a decomposed matrix from a structure.
+    /// Reconstructs a column-decomposed matrix.
     /// @param mat the matrix type
     /// @param root the root value being accessed into
-    /// @param indices the access indices that get to the first column of the decomposed matrix
+    /// @param indices the access indices that index the first column of the matrix.
     /// @returns the loaded matrix
-    Value* LoadMatrix(const core::type::Matrix* mat, Value* root, Vector<Value*, 4> indices) {
-        // Load each column vector from the struct and reconstruct the original matrix type.
+    Value* RebuildMatrix(const core::type::Matrix* mat, Value* root, VectorRef<Value*> indices) {
+        // Recombine each column vector from the struct and reconstruct the original matrix type.
+        bool is_ptr = root->Type()->Is<core::type::Pointer>();
+        Vector<Value*, 4> column_indices(std::move(indices));
         Vector<Value*, 4> args;
         auto first_column = indices.Back()->As<Constant>()->Value()->ValueAs<uint32_t>();
         for (uint32_t i = 0; i < mat->columns(); i++) {
-            indices.Back() = b.Constant(u32(first_column + i));
-            auto* access = b.Access(ty.ptr(uniform, mat->ColumnType()), root, indices);
-            args.Push(b.Load(access->Result(0))->Result(0));
+            column_indices.Back() = b.Constant(u32(first_column + i));
+            if (is_ptr) {
+                auto* access = b.Access(ty.ptr(uniform, mat->ColumnType()), root, column_indices);
+                args.Push(b.Load(access)->Result(0));
+            } else {
+                auto* access = b.Access(mat->ColumnType(), root, column_indices);
+                args.Push(access->Result(0));
+            }
         }
         return b.Construct(mat, std::move(args))->Result(0);
     }
@@ -228,16 +277,10 @@
                         uint32_t index = 0;
                         Vector<Value*, 4> args;
                         for (auto* member : str->Members()) {
-                            if (auto* mat = member->Type()->As<core::type::Matrix>();
-                                mat && NeedsDecomposing(mat)) {
-                                // Extract each decomposed column and reconstruct the matrix.
-                                Vector<Value*, 4> columns;
-                                for (uint32_t i = 0; i < mat->columns(); i++) {
-                                    auto* extract = b.Access(mat->ColumnType(), input, u32(index));
-                                    columns.Push(extract->Result(0));
-                                    index++;
-                                }
-                                args.Push(b.Construct(mat, std::move(columns))->Result(0));
+                            if (auto* mat = NeedsDecomposing(member->Type())) {
+                                args.Push(
+                                    RebuildMatrix(mat, input, Vector{b.Constant(u32(index))}));
+                                index += mat->columns();
                             } else {
                                 // Extract and convert the member.
                                 auto* type = input_str->Element(index);
@@ -268,6 +311,12 @@
                 });
                 return b.Load(new_arr)->Result(0);
             },
+            [&](const core::type::Matrix* mat) -> Value* {
+                if (!NeedsDecomposing(mat)) {
+                    return source;
+                }
+                return RebuildMatrix(mat, source, Vector{b.Constant(u32(0))});
+            },
             [&](Default) { return source; });
     }
 
@@ -279,28 +328,75 @@
             tint::Switch(
                 inst,  //
                 [&](Access* access) {
+                    auto* object_ty = access->Object()->Type()->As<core::type::MemoryView>();
+                    if (!object_ty || object_ty->AddressSpace() != core::AddressSpace::kUniform) {
+                        // Access to non-uniform memory views does not require transformation.
+                        return;
+                    }
+
+                    if (!replacement->Type()->Is<core::type::MemoryView>()) {
+                        // The replacement is a value, in which case the decomposed matrix has
+                        // already been reconstructed. In this situation the access only needs its
+                        // return type updating, and downstream instructions need updating.
+                        access->SetOperand(Access::kObjectOperandOffset, replacement);
+                        auto* result = access->Result(0);
+                        result->SetType(result->Type()->UnwrapPtrOrRef());
+                        result->ForEachUse([&](Usage use) { Replace(use.instruction, result); });
+                        return;
+                    }
+
                     // Modify the access indices to take decomposed matrices into account.
-                    auto* current_type = access->Object()->Type()->UnwrapPtr();
+                    auto* current_type = object_ty->StoreType();
                     Vector<Value*, 4> indices;
-                    for (auto idx : access->Indices()) {
-                        if (auto* str = current_type->As<core::type::Struct>()) {
+
+                    if (NeedsDecomposing(current_type)) {
+                        // Decomposed matrices are indexed using their first column vector
+                        indices.Push(b.Constant(0_u));
+                    }
+
+                    for (size_t i = 0, n = access->Indices().Length(); i < n; i++) {
+                        auto* idx = access->Indices()[i];
+
+                        if (auto* mat = NeedsDecomposing(current_type)) {
+                            // Access chain passes through decomposed matrix.
+                            if (auto* const_idx = idx->As<Constant>()) {
+                                // Column vector index is a constant.
+                                // Instead of loading the whole matrix, fold the access of the
+                                // matrix and the constant column index into an single access of
+                                // column vector member.
+                                auto* base_idx = indices.Back()->As<Constant>();
+                                indices.Back() =
+                                    b.Constant(u32(base_idx->Value()->ValueAs<uint32_t>() +
+                                                   const_idx->Value()->ValueAs<uint32_t>()));
+                                current_type = mat->ColumnType();
+                                i++;  // We've already consumed the column access
+                            } else {
+                                // Column vector index is dynamic.
+                                // Reconstruct the whole matrix and index that.
+                                replacement = RebuildMatrix(mat, replacement, std::move(indices));
+                                indices.Clear();
+                                indices.Push(idx);
+                                current_type = mat->ColumnType();
+                            }
+                        } else if (auto* str = current_type->As<core::type::Struct>()) {
+                            // Remap member index
                             uint32_t old_index = idx->As<Constant>()->Value()->ValueAs<uint32_t>();
                             uint32_t new_index = *member_index_map.Get(str->Members()[old_index]);
-                            indices.Push(b.Constant(u32(new_index)));
                             current_type = str->Element(old_index);
+                            indices.Push(b.Constant(u32(new_index)));
                         } else {
                             indices.Push(idx);
                             current_type = current_type->Elements().type;
+                            if (NeedsDecomposing(current_type)) {
+                                // Decomposed matrices are indexed using their first column vector
+                                indices.Push(b.Constant(0_u));
+                            }
                         }
+                    }
 
-                        // If we've hit a matrix that was decomposed, load the whole matrix.
-                        // Any additional accesses will extract columns instead of producing
-                        // pointers.
-                        if (auto* mat = current_type->As<core::type::Matrix>();
-                            mat && NeedsDecomposing(mat)) {
-                            replacement = LoadMatrix(mat, replacement, std::move(indices));
-                            indices.Clear();
-                        }
+                    if (auto* mat = NeedsDecomposing(current_type)) {
+                        replacement = RebuildMatrix(mat, replacement, std::move(indices));
+                        indices.Clear();
                     }
 
                     if (!indices.IsEmpty()) {
diff --git a/src/tint/lang/core/ir/transform/std140.h b/src/tint/lang/core/ir/transform/std140.h
index f931546..89852db 100644
--- a/src/tint/lang/core/ir/transform/std140.h
+++ b/src/tint/lang/core/ir/transform/std140.h
@@ -41,6 +41,8 @@
 
 /// Std140 is a transform that rewrites matrix types in the uniform address space to conform to
 /// GLSL's std140 layout rules.
+/// @note requires the DirectVariableAccess transform to have been run first to remove uniform
+/// pointer parameters.
 /// @param module the module to transform
 /// @returns success or failure
 Result<SuccessType> Std140(Module& module);
diff --git a/src/tint/lang/core/ir/transform/std140_fuzz.cc b/src/tint/lang/core/ir/transform/std140_fuzz.cc
index 86e5730..091def8 100644
--- a/src/tint/lang/core/ir/transform/std140_fuzz.cc
+++ b/src/tint/lang/core/ir/transform/std140_fuzz.cc
@@ -28,12 +28,31 @@
 #include "src/tint/lang/core/ir/transform/std140.h"
 
 #include "src/tint/cmd/fuzz/ir/fuzz.h"
+#include "src/tint/lang/core/address_space.h"
+#include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/core/type/pointer.h"
 
 namespace tint::core::ir::transform {
 namespace {
 
+bool CanRun(Module& module) {
+    for (auto& fn : module.functions) {
+        for (auto* param : fn->Params()) {
+            if (auto* ptr = param->Type()->As<core::type::Pointer>();
+                ptr && ptr->AddressSpace() == core::AddressSpace::kUniform) {
+                return false;  // Requires the DirectVariableAccess transform
+            }
+        }
+    }
+    return true;
+}
+
 void Std140Fuzzer(Module& module) {
+    if (!CanRun(module)) {
+        return;
+    }
+
     if (auto res = Std140(module); res != Success) {
         return;
     }
diff --git a/src/tint/lang/core/ir/transform/std140_test.cc b/src/tint/lang/core/ir/transform/std140_test.cc
index dd55108..28a67fa 100644
--- a/src/tint/lang/core/ir/transform/std140_test.cc
+++ b/src/tint/lang/core/ir/transform/std140_test.cc
@@ -29,6 +29,8 @@
 
 #include <utility>
 
+#include "src/tint/lang/core/fluent_types.h"
+#include "src/tint/lang/core/ir/load_vector_element.h"
 #include "src/tint/lang/core/ir/transform/helper_test.h"
 #include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/matrix.h"
@@ -148,8 +150,7 @@
     EXPECT_EQ(expect, str());
 }
 
-// Test that we do not decompose a mat2x2 that is used an array element type.
-TEST_F(IR_Std140Test, NoModify_Mat2x2_InsideArray) {
+TEST_F(IR_Std140Test, Load_Mat2x2f_InArray) {
     auto* mat = ty.mat2x2<f32>();
     auto* structure =
         ty.Struct(mod.symbols.New("MyStruct"), {
@@ -186,7 +187,35 @@
 )";
     EXPECT_EQ(src, str());
 
-    auto* expect = src;
+    auto* expect = R"(
+MyStruct = struct @align(8), @block {
+  arr:array<mat2x2<f32>, 4> @offset(0)
+}
+
+mat2x2_f32_std140 = struct @align(8) {
+  col0:vec2<f32> @offset(0)
+  col1:vec2<f32> @offset(8)
+}
+
+MyStruct_std140 = struct @align(8), @block {
+  arr:array<mat2x2_f32_std140, 4> @offset(0)
+}
+
+$B1: {  # root
+  %buffer:ptr<uniform, MyStruct_std140, read> = var @binding_point(0, 0)
+}
+
+%foo = func():mat2x2<f32> {
+  $B2: {
+    %3:ptr<uniform, vec2<f32>, read> = access %buffer, 0u, 2u, 0u
+    %4:vec2<f32> = load %3
+    %5:ptr<uniform, vec2<f32>, read> = access %buffer, 0u, 2u, 1u
+    %6:vec2<f32> = load %5
+    %7:mat2x2<f32> = construct %4, %6
+    ret %7
+  }
+}
+)";
 
     Run(Std140);
 
@@ -264,7 +293,7 @@
     EXPECT_EQ(expect, str());
 }
 
-TEST_F(IR_Std140Test, Mat3x2_LoadColumn) {
+TEST_F(IR_Std140Test, Mat3x2_LoadConstantColumn) {
     auto* mat = ty.mat3x2<f32>();
     auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
                                                                  {mod.symbols.New("a"), mat},
@@ -318,15 +347,83 @@
 
 %foo = func():vec2<f32> {
   $B2: {
-    %3:ptr<uniform, vec2<f32>, read> = access %buffer, 0u
+    %3:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
     %4:vec2<f32> = load %3
-    %5:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
-    %6:vec2<f32> = load %5
-    %7:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
-    %8:vec2<f32> = load %7
-    %9:mat3x2<f32> = construct %4, %6, %8
-    %10:vec2<f32> = access %9, 1u
-    ret %10
+    ret %4
+  }
+}
+)";
+
+    Run(Std140);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_Std140Test, Mat3x2_LoadDynamicColumn) {
+    auto* mat = ty.mat3x2<f32>();
+    auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                                 {mod.symbols.New("a"), mat},
+                                                             });
+    structure->SetStructFlag(core::type::kBlock);
+
+    auto* buffer = b.Var("buffer", ty.ptr(uniform, structure));
+    buffer->SetBindingPoint(0, 0);
+    mod.root_block->Append(buffer);
+
+    auto* func = b.Function("foo", mat->ColumnType());
+    auto* column = b.FunctionParam<i32>("column");
+    func->AppendParam(column);
+    b.Append(func->Block(), [&] {
+        auto* access = b.Access(ty.ptr(uniform, mat->ColumnType()), buffer, 0_u, column);
+        auto* load = b.Load(access);
+        b.Return(func, load);
+    });
+
+    auto* src = R"(
+MyStruct = struct @align(8), @block {
+  a:mat3x2<f32> @offset(0)
+}
+
+$B1: {  # root
+  %buffer:ptr<uniform, MyStruct, read> = var @binding_point(0, 0)
+}
+
+%foo = func(%column:i32):vec2<f32> {
+  $B2: {
+    %4:ptr<uniform, vec2<f32>, read> = access %buffer, 0u, %column
+    %5:vec2<f32> = load %4
+    ret %5
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+MyStruct = struct @align(8), @block {
+  a:mat3x2<f32> @offset(0)
+}
+
+MyStruct_std140 = struct @align(8), @block {
+  a_col0:vec2<f32> @offset(0)
+  a_col1:vec2<f32> @offset(8)
+  a_col2:vec2<f32> @offset(16)
+}
+
+$B1: {  # root
+  %buffer:ptr<uniform, MyStruct_std140, read> = var @binding_point(0, 0)
+}
+
+%foo = func(%column:i32):vec2<f32> {
+  $B2: {
+    %4:ptr<uniform, vec2<f32>, read> = access %buffer, 0u
+    %5:vec2<f32> = load %4
+    %6:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
+    %7:vec2<f32> = load %6
+    %8:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
+    %9:vec2<f32> = load %8
+    %10:mat3x2<f32> = construct %5, %7, %9
+    %11:vec2<f32> = access %10, %column
+    ret %11
   }
 }
 )";
@@ -390,16 +487,9 @@
 
 %foo = func():f32 {
   $B2: {
-    %3:ptr<uniform, vec2<f32>, read> = access %buffer, 0u
-    %4:vec2<f32> = load %3
-    %5:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
-    %6:vec2<f32> = load %5
-    %7:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
-    %8:vec2<f32> = load %7
-    %9:mat3x2<f32> = construct %4, %6, %8
-    %10:vec2<f32> = access %9, 1u
-    %11:f32 = access %10, 1u
-    ret %11
+    %3:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
+    %4:f32 = load_vector_element %3, 1u
+    ret %4
   }
 }
 )";
@@ -1709,48 +1799,35 @@
 }
 %load_vec_b = func():f32 {
   $B7: {
-    %29:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
+    %29:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
     %30:vec2<f32> = load %29
-    %31:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
-    %32:vec2<f32> = load %31
-    %33:ptr<uniform, vec2<f32>, read> = access %buffer, 3u
-    %34:vec2<f32> = load %33
-    %35:mat3x2<f32> = construct %30, %32, %34
-    %36:vec2<f32> = access %35, 1u
-    %37:f32 = access %36, 1u
-    ret %37
+    %31:f32 = access %30, 1u
+    ret %31
   }
 }
 %lve_a = func():f32 {
   $B8: {
-    %39:ptr<uniform, vec4<f32>, read> = access %buffer, 0u, 1u
-    %40:f32 = load_vector_element %39, 1u
-    ret %40
+    %33:ptr<uniform, vec4<f32>, read> = access %buffer, 0u, 1u
+    %34:f32 = load_vector_element %33, 1u
+    ret %34
   }
 }
 %lve_b = func():f32 {
   $B9: {
-    %42:ptr<uniform, vec2<f32>, read> = access %buffer, 1u
-    %43:vec2<f32> = load %42
-    %44:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
-    %45:vec2<f32> = load %44
-    %46:ptr<uniform, vec2<f32>, read> = access %buffer, 3u
-    %47:vec2<f32> = load %46
-    %48:mat3x2<f32> = construct %43, %45, %47
-    %49:vec2<f32> = access %48, 1u
-    %50:f32 = access %49, 1u
-    ret %50
+    %36:ptr<uniform, vec2<f32>, read> = access %buffer, 2u
+    %37:f32 = load_vector_element %36, 1u
+    ret %37
   }
 }
 %convert_MyStruct = func(%input:MyStruct_std140):MyStruct {
   $B10: {
-    %52:mat4x4<f32> = access %input, 0u
-    %53:vec2<f32> = access %input, 1u
-    %54:vec2<f32> = access %input, 2u
-    %55:vec2<f32> = access %input, 3u
-    %56:mat3x2<f32> = construct %53, %54, %55
-    %57:MyStruct = construct %52, %56
-    ret %57
+    %39:mat4x4<f32> = access %input, 0u
+    %40:vec2<f32> = access %input, 1u
+    %41:vec2<f32> = access %input, 2u
+    %42:vec2<f32> = access %input, 3u
+    %43:mat3x2<f32> = construct %40, %41, %42
+    %44:MyStruct = construct %39, %43
+    ret %44
   }
 }
 )";
@@ -1857,48 +1934,295 @@
     %14:vec4<f16> = load %13
     %15:mat4x4<f16> = construct %8, %10, %12, %14
     %mat:mat4x4<f16> = let %15
-    %17:ptr<uniform, vec3<f16>, read> = access %buffer, 4u
+    %17:ptr<uniform, vec3<f16>, read> = access %buffer, 5u
     %18:vec3<f16> = load %17
-    %19:ptr<uniform, vec3<f16>, read> = access %buffer, 5u
-    %20:vec3<f16> = load %19
-    %21:ptr<uniform, vec3<f16>, read> = access %buffer, 6u
-    %22:vec3<f16> = load %21
-    %23:ptr<uniform, vec3<f16>, read> = access %buffer, 7u
-    %24:vec3<f16> = load %23
-    %25:mat4x3<f16> = construct %18, %20, %22, %24
-    %26:vec3<f16> = access %25, 1u
-    %col:vec3<f16> = let %26
-    %28:ptr<uniform, vec4<f16>, read> = access %buffer, 2u
-    %29:vec4<f16> = load %28
-    %30:ptr<uniform, vec4<f16>, read> = access %buffer, 3u
-    %31:vec4<f16> = load %30
-    %32:mat2x4<f16> = construct %29, %31
-    %33:vec4<f16> = access %32, 0u
-    %34:f16 = access %33, 3u
-    %el:f16 = let %34
+    %col:vec3<f16> = let %18
+    %20:ptr<uniform, vec4<f16>, read> = access %buffer, 2u
+    %21:f16 = load_vector_element %20, 3u
+    %el:f16 = let %21
     ret
   }
 }
 %convert_MyStruct = func(%input:MyStruct_std140):MyStruct {
   $B3: {
-    %37:vec2<f16> = access %input, 0u
-    %38:vec2<f16> = access %input, 1u
-    %39:mat2x2<f16> = construct %37, %38
-    %40:vec4<f16> = access %input, 2u
-    %41:vec4<f16> = access %input, 3u
-    %42:mat2x4<f16> = construct %40, %41
-    %43:vec3<f16> = access %input, 4u
-    %44:vec3<f16> = access %input, 5u
-    %45:vec3<f16> = access %input, 6u
-    %46:vec3<f16> = access %input, 7u
-    %47:mat4x3<f16> = construct %43, %44, %45, %46
-    %48:vec4<f16> = access %input, 8u
-    %49:vec4<f16> = access %input, 9u
-    %50:vec4<f16> = access %input, 10u
-    %51:vec4<f16> = access %input, 11u
-    %52:mat4x4<f16> = construct %48, %49, %50, %51
-    %53:MyStruct = construct %39, %42, %47, %52
-    ret %53
+    %24:vec2<f16> = access %input, 0u
+    %25:vec2<f16> = access %input, 1u
+    %26:mat2x2<f16> = construct %24, %25
+    %27:vec4<f16> = access %input, 2u
+    %28:vec4<f16> = access %input, 3u
+    %29:mat2x4<f16> = construct %27, %28
+    %30:vec3<f16> = access %input, 4u
+    %31:vec3<f16> = access %input, 5u
+    %32:vec3<f16> = access %input, 6u
+    %33:vec3<f16> = access %input, 7u
+    %34:mat4x3<f16> = construct %30, %31, %32, %33
+    %35:vec4<f16> = access %input, 8u
+    %36:vec4<f16> = access %input, 9u
+    %37:vec4<f16> = access %input, 10u
+    %38:vec4<f16> = access %input, 11u
+    %39:mat4x4<f16> = construct %35, %36, %37, %38
+    %40:MyStruct = construct %26, %29, %34, %39
+    ret %40
+  }
+}
+)";
+
+    Run(Std140);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_Std140Test, Mat3x3f_And_ArrayMat4x3f) {  // crbug.com/338727551
+    auto* s =
+        ty.Struct(mod.symbols.New("S"), {
+                                            {mod.symbols.New("a"), ty.mat3x3<f32>()},
+                                            {mod.symbols.New("b"), ty.array<mat4x3<f32>, 3>()},
+                                        });
+    s->SetStructFlag(core::type::kBlock);
+
+    auto* u = b.Var("u", ty.ptr(uniform, s));
+    u->SetBindingPoint(0, 0);
+    mod.root_block->Append(u);
+
+    auto* f = b.Function("F", ty.f32());
+    b.Append(f->Block(), [&] {
+        auto* p = b.Access<ptr<uniform, vec3<f32>, read>>(u, 1_u, 0_u, 0_u);
+        auto* x = b.LoadVectorElement(p, 0_u);
+        b.Return(f, x);
+    });
+
+    auto* src = R"(
+S = struct @align(16), @block {
+  a:mat3x3<f32> @offset(0)
+  b:array<mat4x3<f32>, 3> @offset(48)
+}
+
+$B1: {  # root
+  %u:ptr<uniform, S, read> = var @binding_point(0, 0)
+}
+
+%F = func():f32 {
+  $B2: {
+    %3:ptr<uniform, vec3<f32>, read> = access %u, 1u, 0u, 0u
+    %4:f32 = load_vector_element %3, 0u
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+S = struct @align(16), @block {
+  a:mat3x3<f32> @offset(0)
+  b:array<mat4x3<f32>, 3> @offset(48)
+}
+
+mat4x3_f32_std140 = struct @align(16) {
+  col0:vec3<f32> @offset(0)
+  col1:vec3<f32> @offset(16)
+  col2:vec3<f32> @offset(32)
+  col3:vec3<f32> @offset(48)
+}
+
+S_std140 = struct @align(16), @block {
+  a_col0:vec3<f32> @offset(0)
+  a_col1:vec3<f32> @offset(16)
+  a_col2:vec3<f32> @offset(32)
+  b:array<mat4x3_f32_std140, 3> @offset(48)
+}
+
+$B1: {  # root
+  %u:ptr<uniform, S_std140, read> = var @binding_point(0, 0)
+}
+
+%F = func():f32 {
+  $B2: {
+    %3:ptr<uniform, vec3<f32>, read> = access %u, 3u, 0u, 0u
+    %4:f32 = load_vector_element %3, 0u
+    ret %4
+  }
+}
+)";
+
+    Run(Std140);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_Std140Test, Mat3x3f_And_ArrayStructMat4x3f) {
+    auto* s1 =
+        ty.Struct(mod.symbols.New("S1"), {
+                                             {mod.symbols.New("c"), ty.mat3x3<f32>()},
+                                             {mod.symbols.New("d"), ty.array<mat4x3<f32>, 3>()},
+                                         });
+    auto* s2 = ty.Struct(mod.symbols.New("S2"), {
+                                                    {mod.symbols.New("a"), ty.mat3x3<f32>()},
+                                                    {mod.symbols.New("b"), s1},
+                                                });
+    s2->SetStructFlag(core::type::kBlock);
+
+    auto* u = b.Var("u", ty.ptr(uniform, s2));
+    u->SetBindingPoint(0, 0);
+    mod.root_block->Append(u);
+
+    auto* f = b.Function("F", ty.f32());
+    b.Append(f->Block(), [&] {
+        auto* p = b.Access<ptr<uniform, vec3<f32>, read>>(u, 1_u, 1_u, 0_u, 0_u);
+        auto* x = b.LoadVectorElement(p, 0_u);
+        b.Return(f, x);
+    });
+
+    auto* src = R"(
+S1 = struct @align(16) {
+  c:mat3x3<f32> @offset(0)
+  d:array<mat4x3<f32>, 3> @offset(48)
+}
+
+S2 = struct @align(16), @block {
+  a:mat3x3<f32> @offset(0)
+  b:S1 @offset(48)
+}
+
+$B1: {  # root
+  %u:ptr<uniform, S2, read> = var @binding_point(0, 0)
+}
+
+%F = func():f32 {
+  $B2: {
+    %3:ptr<uniform, vec3<f32>, read> = access %u, 1u, 1u, 0u, 0u
+    %4:f32 = load_vector_element %3, 0u
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+S1 = struct @align(16) {
+  c:mat3x3<f32> @offset(0)
+  d:array<mat4x3<f32>, 3> @offset(48)
+}
+
+S2 = struct @align(16), @block {
+  a:mat3x3<f32> @offset(0)
+  b:S1 @offset(48)
+}
+
+mat4x3_f32_std140 = struct @align(16) {
+  col0:vec3<f32> @offset(0)
+  col1:vec3<f32> @offset(16)
+  col2:vec3<f32> @offset(32)
+  col3:vec3<f32> @offset(48)
+}
+
+S1_std140 = struct @align(16) {
+  c_col0:vec3<f32> @offset(0)
+  c_col1:vec3<f32> @offset(16)
+  c_col2:vec3<f32> @offset(32)
+  d:array<mat4x3_f32_std140, 3> @offset(48)
+}
+
+S2_std140 = struct @align(16), @block {
+  a_col0:vec3<f32> @offset(0)
+  a_col1:vec3<f32> @offset(16)
+  a_col2:vec3<f32> @offset(32)
+  b:S1_std140 @offset(48)
+}
+
+$B1: {  # root
+  %u:ptr<uniform, S2_std140, read> = var @binding_point(0, 0)
+}
+
+%F = func():f32 {
+  $B2: {
+    %3:ptr<uniform, vec3<f32>, read> = access %u, 3u, 3u, 0u, 0u
+    %4:f32 = load_vector_element %3, 0u
+    ret %4
+  }
+}
+)";
+
+    Run(Std140);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_Std140Test, Mat3x3f_And_ArrayStructMat2x2f) {
+    auto* s1 = ty.Struct(mod.symbols.New("S1"), {
+                                                    {mod.symbols.New("c"), ty.mat2x2<f32>()},
+                                                });
+    auto* s2 = ty.Struct(mod.symbols.New("S2"), {
+                                                    {mod.symbols.New("a"), ty.mat3x3<f32>()},
+                                                    {mod.symbols.New("b"), s1},
+                                                });
+    s2->SetStructFlag(core::type::kBlock);
+
+    auto* u = b.Var("u", ty.ptr(uniform, s2));
+    u->SetBindingPoint(0, 0);
+    mod.root_block->Append(u);
+
+    auto* f = b.Function("F", ty.f32());
+    b.Append(f->Block(), [&] {
+        auto* p = b.Access<ptr<uniform, vec2<f32>, read>>(u, 1_u, 0_u, 0_u);
+        auto* x = b.LoadVectorElement(p, 0_u);
+        b.Return(f, x);
+    });
+
+    auto* src = R"(
+S1 = struct @align(8) {
+  c:mat2x2<f32> @offset(0)
+}
+
+S2 = struct @align(16), @block {
+  a:mat3x3<f32> @offset(0)
+  b:S1 @offset(48)
+}
+
+$B1: {  # root
+  %u:ptr<uniform, S2, read> = var @binding_point(0, 0)
+}
+
+%F = func():f32 {
+  $B2: {
+    %3:ptr<uniform, vec2<f32>, read> = access %u, 1u, 0u, 0u
+    %4:f32 = load_vector_element %3, 0u
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+S1 = struct @align(8) {
+  c:mat2x2<f32> @offset(0)
+}
+
+S2 = struct @align(16), @block {
+  a:mat3x3<f32> @offset(0)
+  b:S1 @offset(48)
+}
+
+S1_std140 = struct @align(8) {
+  c_col0:vec2<f32> @offset(0)
+  c_col1:vec2<f32> @offset(8)
+}
+
+S2_std140 = struct @align(16), @block {
+  a_col0:vec3<f32> @offset(0)
+  a_col1:vec3<f32> @offset(16)
+  a_col2:vec3<f32> @offset(32)
+  b:S1_std140 @offset(48)
+}
+
+$B1: {  # root
+  %u:ptr<uniform, S2_std140, read> = var @binding_point(0, 0)
+}
+
+%F = func():f32 {
+  $B2: {
+    %3:ptr<uniform, vec2<f32>, read> = access %u, 3u, 0u
+    %4:f32 = load_vector_element %3, 0u
+    ret %4
   }
 }
 )";