[ir] Fix Std140 for non-decomposed matrices

Recursively replacing instructions from a buffer that contains a
decomposed matrix cannot assume that all instructions are operating on
decomposed matrices, since there may be other matrices in the same
struct that were not decomposed.

Fixed: tint:2100
Change-Id: Ifc28636a39276d931592bcc1fe2026540d25c8c3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/160742
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/transform/std140.cc b/src/tint/lang/core/ir/transform/std140.cc
index 0f2b8f6..56897fd 100644
--- a/src/tint/lang/core/ir/transform/std140.cc
+++ b/src/tint/lang/core/ir/transform/std140.cc
@@ -330,12 +330,17 @@
                     load->Destroy();
                 },
                 [&](LoadVectorElement* load) {
-                    // We should have loaded the decomposed matrix, reconstructed it, so this is now
-                    // extracting from a value type.
-                    TINT_ASSERT(!replacement->Type()->Is<core::type::Pointer>());
-                    auto* access = b.Access(load->Result()->Type(), replacement, load->Index());
-                    load->Result()->ReplaceAllUsesWith(access->Result());
-                    load->Destroy();
+                    if (!replacement->Type()->Is<core::type::Pointer>()) {
+                        // We have loaded a decomposed matrix and reconstructed it, so this is now
+                        // extracting from a value type.
+                        auto* access = b.Access(load->Result()->Type(), replacement, load->Index());
+                        load->Result()->ReplaceAllUsesWith(access->Result());
+                        load->Destroy();
+                    } else {
+                        // There was no decomposed matrix on the path to this instruction so just
+                        // update the source operand.
+                        load->SetOperand(LoadVectorElement::kFromOperandOffset, replacement);
+                    }
                 },
                 [&](Let* let) {
                     // Let instructions just fold away.
diff --git a/src/tint/lang/core/ir/transform/std140_test.cc b/src/tint/lang/core/ir/transform/std140_test.cc
index 8415529..105f2cd 100644
--- a/src/tint/lang/core/ir/transform/std140_test.cc
+++ b/src/tint/lang/core/ir/transform/std140_test.cc
@@ -1483,6 +1483,283 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(IR_Std140Test, NotAllMatricesDecomposed) {
+    auto* mat4x4 = ty.mat4x4<f32>();
+    auto* mat3x2 = ty.mat3x2<f32>();
+    auto* structure = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                                 {mod.symbols.New("a"), mat4x4},
+                                                                 {mod.symbols.New("b"), mat3x2},
+                                                             });
+    structure->SetStructFlag(core::type::kBlock);
+
+    auto* buffer = b.Var("buffer", ty.ptr(uniform, structure));
+    buffer->SetBindingPoint(0, 0);
+    mod.root_block->Append(buffer);
+
+    {
+        auto* func = b.Function("load_struct_a", mat4x4);
+        b.Append(func->Block(), [&] {
+            auto* load_struct = b.Load(buffer);
+            auto* extract_mat = b.Access(mat4x4, load_struct, 0_u);
+            b.Return(func, extract_mat);
+        });
+    }
+
+    {
+        auto* func = b.Function("load_struct_b", mat3x2);
+        b.Append(func->Block(), [&] {
+            auto* load_struct = b.Load(buffer);
+            auto* extract_mat = b.Access(mat3x2, load_struct, 1_u);
+            b.Return(func, extract_mat);
+        });
+    }
+
+    {
+        auto* func = b.Function("load_mat_a", ty.vec4<f32>());
+        b.Append(func->Block(), [&] {
+            auto* access_mat = b.Access(ty.ptr(uniform, mat4x4), buffer, 0_u);
+            auto* load_mat = b.Load(access_mat);
+            auto* extract_vec = b.Access(ty.vec4<f32>(), load_mat, 0_u);
+            b.Return(func, extract_vec);
+        });
+    }
+
+    {
+        auto* func = b.Function("load_mat_b", ty.vec2<f32>());
+        b.Append(func->Block(), [&] {
+            auto* access_mat = b.Access(ty.ptr(uniform, mat3x2), buffer, 1_u);
+            auto* load_mat = b.Load(access_mat);
+            auto* extract_vec = b.Access(ty.vec2<f32>(), load_mat, 0_u);
+            b.Return(func, extract_vec);
+        });
+    }
+
+    {
+        auto* func = b.Function("load_vec_a", ty.f32());
+        b.Append(func->Block(), [&] {
+            auto* access_vec = b.Access(ty.ptr(uniform, mat4x4->ColumnType()), buffer, 0_u, 1_u);
+            auto* load_vec = b.Load(access_vec);
+            auto* extract_el = b.Access(ty.f32(), load_vec, 1_u);
+            b.Return(func, extract_el);
+        });
+    }
+
+    {
+        auto* func = b.Function("load_vec_b", ty.f32());
+        b.Append(func->Block(), [&] {
+            auto* access_vec = b.Access(ty.ptr(uniform, mat3x2->ColumnType()), buffer, 1_u, 1_u);
+            auto* load_vec = b.Load(access_vec);
+            auto* extract_el = b.Access(ty.f32(), load_vec, 1_u);
+            b.Return(func, extract_el);
+        });
+    }
+
+    {
+        auto* func = b.Function("lve_a", ty.f32());
+        b.Append(func->Block(), [&] {
+            auto* access_vec = b.Access(ty.ptr(uniform, mat4x4->ColumnType()), buffer, 0_u, 1_u);
+            auto* lve = b.LoadVectorElement(access_vec, 1_u);
+            b.Return(func, lve);
+        });
+    }
+
+    {
+        auto* func = b.Function("lve_b", ty.f32());
+        b.Append(func->Block(), [&] {
+            auto* access_vec = b.Access(ty.ptr(uniform, mat3x2->ColumnType()), buffer, 1_u, 1_u);
+            auto* lve = b.LoadVectorElement(access_vec, 1_u);
+            b.Return(func, lve);
+        });
+    }
+
+    auto* src = R"(
+MyStruct = struct @align(16), @block {
+  a:mat4x4<f32> @offset(0)
+  b:mat3x2<f32> @offset(64)
+}
+
+%b1 = block {  # root
+  %buffer:ptr<uniform, MyStruct, read_write> = var @binding_point(0, 0)
+}
+
+%load_struct_a = func():mat4x4<f32> -> %b2 {
+  %b2 = block {
+    %3:MyStruct = load %buffer
+    %4:mat4x4<f32> = access %3, 0u
+    ret %4
+  }
+}
+%load_struct_b = func():mat3x2<f32> -> %b3 {
+  %b3 = block {
+    %6:MyStruct = load %buffer
+    %7:mat3x2<f32> = access %6, 1u
+    ret %7
+  }
+}
+%load_mat_a = func():vec4<f32> -> %b4 {
+  %b4 = block {
+    %9:ptr<uniform, mat4x4<f32>, read_write> = access %buffer, 0u
+    %10:mat4x4<f32> = load %9
+    %11:vec4<f32> = access %10, 0u
+    ret %11
+  }
+}
+%load_mat_b = func():vec2<f32> -> %b5 {
+  %b5 = block {
+    %13:ptr<uniform, mat3x2<f32>, read_write> = access %buffer, 1u
+    %14:mat3x2<f32> = load %13
+    %15:vec2<f32> = access %14, 0u
+    ret %15
+  }
+}
+%load_vec_a = func():f32 -> %b6 {
+  %b6 = block {
+    %17:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+    %18:vec4<f32> = load %17
+    %19:f32 = access %18, 1u
+    ret %19
+  }
+}
+%load_vec_b = func():f32 -> %b7 {
+  %b7 = block {
+    %21:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u, 1u
+    %22:vec2<f32> = load %21
+    %23:f32 = access %22, 1u
+    ret %23
+  }
+}
+%lve_a = func():f32 -> %b8 {
+  %b8 = block {
+    %25:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+    %26:f32 = load_vector_element %25, 1u
+    ret %26
+  }
+}
+%lve_b = func():f32 -> %b9 {
+  %b9 = block {
+    %28:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u, 1u
+    %29:f32 = load_vector_element %28, 1u
+    ret %29
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+MyStruct = struct @align(16), @block {
+  a:mat4x4<f32> @offset(0)
+  b:mat3x2<f32> @offset(64)
+}
+
+MyStruct_std140 = struct @align(16), @block {
+  a:mat4x4<f32> @offset(0)
+  b_col0:vec2<f32> @offset(64)
+  b_col1:vec2<f32> @offset(72)
+  b_col2:vec2<f32> @offset(80)
+}
+
+%b1 = block {  # root
+  %buffer:ptr<uniform, MyStruct_std140, read_write> = var @binding_point(0, 0)
+}
+
+%load_struct_a = func():mat4x4<f32> -> %b2 {
+  %b2 = block {
+    %3:MyStruct_std140 = load %buffer
+    %4:MyStruct = call %convert_MyStruct, %3
+    %6:mat4x4<f32> = access %4, 0u
+    ret %6
+  }
+}
+%load_struct_b = func():mat3x2<f32> -> %b3 {
+  %b3 = block {
+    %8:MyStruct_std140 = load %buffer
+    %9:MyStruct = call %convert_MyStruct, %8
+    %10:mat3x2<f32> = access %9, 1u
+    ret %10
+  }
+}
+%load_mat_a = func():vec4<f32> -> %b4 {
+  %b4 = block {
+    %12:ptr<uniform, mat4x4<f32>, read_write> = access %buffer, 0u
+    %13:mat4x4<f32> = load %12
+    %14:vec4<f32> = access %13, 0u
+    ret %14
+  }
+}
+%load_mat_b = func():vec2<f32> -> %b5 {
+  %b5 = block {
+    %16:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u
+    %17:vec2<f32> = load %16
+    %18:ptr<uniform, vec2<f32>, read_write> = access %buffer, 2u
+    %19:vec2<f32> = load %18
+    %20:ptr<uniform, vec2<f32>, read_write> = access %buffer, 3u
+    %21:vec2<f32> = load %20
+    %22:mat3x2<f32> = construct %17, %19, %21
+    %23:vec2<f32> = access %22, 0u
+    ret %23
+  }
+}
+%load_vec_a = func():f32 -> %b6 {
+  %b6 = block {
+    %25:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+    %26:vec4<f32> = load %25
+    %27:f32 = access %26, 1u
+    ret %27
+  }
+}
+%load_vec_b = func():f32 -> %b7 {
+  %b7 = block {
+    %29:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u
+    %30:vec2<f32> = load %29
+    %31:ptr<uniform, vec2<f32>, read_write> = access %buffer, 2u
+    %32:vec2<f32> = load %31
+    %33:ptr<uniform, vec2<f32>, read_write> = access %buffer, 3u
+    %34:vec2<f32> = load %33
+    %35:mat3x2<f32> = construct %30, %32, %34
+    %36:vec2<f32> = access %35, 1u
+    %37:f32 = access %36, 1u
+    ret %37
+  }
+}
+%lve_a = func():f32 -> %b8 {
+  %b8 = block {
+    %39:ptr<uniform, vec4<f32>, read_write> = access %buffer, 0u, 1u
+    %40:f32 = load_vector_element %39, 1u
+    ret %40
+  }
+}
+%lve_b = func():f32 -> %b9 {
+  %b9 = block {
+    %42:ptr<uniform, vec2<f32>, read_write> = access %buffer, 1u
+    %43:vec2<f32> = load %42
+    %44:ptr<uniform, vec2<f32>, read_write> = access %buffer, 2u
+    %45:vec2<f32> = load %44
+    %46:ptr<uniform, vec2<f32>, read_write> = access %buffer, 3u
+    %47:vec2<f32> = load %46
+    %48:mat3x2<f32> = construct %43, %45, %47
+    %49:vec2<f32> = access %48, 1u
+    %50:f32 = access %49, 1u
+    ret %50
+  }
+}
+%convert_MyStruct = func(%input:MyStruct_std140):MyStruct -> %b10 {
+  %b10 = block {
+    %52:mat4x4<f32> = access %input, 0u
+    %53:vec2<f32> = access %input, 1u
+    %54:vec2<f32> = access %input, 2u
+    %55:vec2<f32> = access %input, 3u
+    %56:mat3x2<f32> = construct %53, %54, %55
+    %57:MyStruct = construct %52, %56
+    ret %57
+  }
+}
+)";
+
+    Run(Std140);
+
+    EXPECT_EQ(expect, str());
+}
+
 TEST_F(IR_Std140Test, F16) {
     auto* structure =
         ty.Struct(mod.symbols.New("MyStruct"), {
diff --git a/test/tint/bug/tint/2100.wgsl b/test/tint/bug/tint/2100.wgsl
new file mode 100644
index 0000000..360ee9c
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl
@@ -0,0 +1,12 @@
+struct S {
+  matrix_view : mat4x4<f32>,
+  matrix_normal : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<uniform> buffer : S;
+
+@vertex
+fn main() -> @builtin(position) vec4f {
+  let x = buffer.matrix_view[0].z;
+  return vec4f(x, 0, 0, 1);
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/2100.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..6b7e066
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.dxc.hlsl
@@ -0,0 +1,19 @@
+cbuffer cbuffer_buffer : register(b0) {
+  uint4 buffer[7];
+};
+
+struct tint_symbol {
+  float4 value : SV_Position;
+};
+
+float4 main_inner() {
+  const float x = asfloat(buffer[0].z);
+  return float4(x, 0.0f, 0.0f, 1.0f);
+}
+
+tint_symbol main() {
+  const float4 inner_result = main_inner();
+  tint_symbol wrapper_result = (tint_symbol)0;
+  wrapper_result.value = inner_result;
+  return wrapper_result;
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/2100.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..6b7e066
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.fxc.hlsl
@@ -0,0 +1,19 @@
+cbuffer cbuffer_buffer : register(b0) {
+  uint4 buffer[7];
+};
+
+struct tint_symbol {
+  float4 value : SV_Position;
+};
+
+float4 main_inner() {
+  const float x = asfloat(buffer[0].z);
+  return float4(x, 0.0f, 0.0f, 1.0f);
+}
+
+tint_symbol main() {
+  const float4 inner_result = main_inner();
+  tint_symbol wrapper_result = (tint_symbol)0;
+  wrapper_result.value = inner_result;
+  return wrapper_result;
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.glsl b/test/tint/bug/tint/2100.wgsl.expected.glsl
new file mode 100644
index 0000000..110fd04
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.glsl
@@ -0,0 +1,24 @@
+#version 310 es
+
+struct S {
+  mat4 matrix_view;
+  mat3 matrix_normal;
+};
+
+layout(binding = 0, std140) uniform tint_symbol_block_ubo {
+  S inner;
+} tint_symbol;
+
+vec4 tint_symbol_1() {
+  float x = tint_symbol.inner.matrix_view[0].z;
+  return vec4(x, 0.0f, 0.0f, 1.0f);
+}
+
+void main() {
+  gl_PointSize = 1.0;
+  vec4 inner_result = tint_symbol_1();
+  gl_Position = inner_result;
+  gl_Position.y = -(gl_Position.y);
+  gl_Position.z = ((2.0f * gl_Position.z) - gl_Position.w);
+  return;
+}
diff --git a/test/tint/bug/tint/2100.wgsl.expected.msl b/test/tint/bug/tint/2100.wgsl.expected.msl
new file mode 100644
index 0000000..89c4fe8
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.msl
@@ -0,0 +1,47 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template<typename T, size_t N>
+struct tint_array {
+    const constant T& operator[](size_t i) const constant { return elements[i]; }
+    device T& operator[](size_t i) device { return elements[i]; }
+    const device T& operator[](size_t i) const device { return elements[i]; }
+    thread T& operator[](size_t i) thread { return elements[i]; }
+    const thread T& operator[](size_t i) const thread { return elements[i]; }
+    threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
+    const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
+    T elements[N];
+};
+
+struct tint_packed_vec3_f32_array_element {
+  /* 0x0000 */ packed_float3 elements;
+  /* 0x000c */ tint_array<int8_t, 4> tint_pad;
+};
+
+struct S_tint_packed_vec3 {
+  /* 0x0000 */ float4x4 matrix_view;
+  /* 0x0040 */ tint_array<tint_packed_vec3_f32_array_element, 3> matrix_normal;
+};
+
+struct S {
+  float4x4 matrix_view;
+  float3x3 matrix_normal;
+};
+
+struct tint_symbol_2 {
+  float4 value [[position]];
+};
+
+float4 tint_symbol_1_inner(const constant S_tint_packed_vec3* const tint_symbol_3) {
+  float const x = (*(tint_symbol_3)).matrix_view[0][2];
+  return float4(x, 0.0f, 0.0f, 1.0f);
+}
+
+vertex tint_symbol_2 tint_symbol_1(const constant S_tint_packed_vec3* tint_symbol_4 [[buffer(0)]]) {
+  float4 const inner_result = tint_symbol_1_inner(tint_symbol_4);
+  tint_symbol_2 wrapper_result = {};
+  wrapper_result.value = inner_result;
+  return wrapper_result;
+}
+
diff --git a/test/tint/bug/tint/2100.wgsl.expected.spvasm b/test/tint/bug/tint/2100.wgsl.expected.spvasm
new file mode 100644
index 0000000..c3f7f3f
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.spvasm
@@ -0,0 +1,70 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 34
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %main "main" %value %vertex_point_size
+               OpName %value "value"
+               OpName %vertex_point_size "vertex_point_size"
+               OpName %buffer_block "buffer_block"
+               OpMemberName %buffer_block 0 "inner"
+               OpName %S "S"
+               OpMemberName %S 0 "matrix_view"
+               OpMemberName %S 1 "matrix_normal"
+               OpName %buffer "buffer"
+               OpName %main_inner "main_inner"
+               OpName %main "main"
+               OpDecorate %value BuiltIn Position
+               OpDecorate %vertex_point_size BuiltIn PointSize
+               OpDecorate %buffer_block Block
+               OpMemberDecorate %buffer_block 0 Offset 0
+               OpMemberDecorate %S 0 Offset 0
+               OpMemberDecorate %S 0 ColMajor
+               OpMemberDecorate %S 0 MatrixStride 16
+               OpMemberDecorate %S 1 Offset 64
+               OpMemberDecorate %S 1 ColMajor
+               OpMemberDecorate %S 1 MatrixStride 16
+               OpDecorate %buffer NonWritable
+               OpDecorate %buffer DescriptorSet 0
+               OpDecorate %buffer Binding 0
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+          %5 = OpConstantNull %v4float
+      %value = OpVariable %_ptr_Output_v4float Output %5
+%_ptr_Output_float = OpTypePointer Output %float
+          %8 = OpConstantNull %float
+%vertex_point_size = OpVariable %_ptr_Output_float Output %8
+%mat4v4float = OpTypeMatrix %v4float 4
+    %v3float = OpTypeVector %float 3
+%mat3v3float = OpTypeMatrix %v3float 3
+          %S = OpTypeStruct %mat4v4float %mat3v3float
+%buffer_block = OpTypeStruct %S
+%_ptr_Uniform_buffer_block = OpTypePointer Uniform %buffer_block
+     %buffer = OpVariable %_ptr_Uniform_buffer_block Uniform
+         %16 = OpTypeFunction %v4float
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+        %int = OpTypeInt 32 1
+         %22 = OpConstantNull %int
+     %uint_2 = OpConstant %uint 2
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+    %float_1 = OpConstant %float 1
+       %void = OpTypeVoid
+         %29 = OpTypeFunction %void
+ %main_inner = OpFunction %v4float None %16
+         %18 = OpLabel
+         %25 = OpAccessChain %_ptr_Uniform_float %buffer %uint_0 %uint_0 %22 %uint_2
+         %26 = OpLoad %float %25
+         %28 = OpCompositeConstruct %v4float %26 %8 %8 %float_1
+               OpReturnValue %28
+               OpFunctionEnd
+       %main = OpFunction %void None %29
+         %32 = OpLabel
+         %33 = OpFunctionCall %v4float %main_inner
+               OpStore %value %33
+               OpStore %vertex_point_size %float_1
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/tint/2100.wgsl.expected.wgsl b/test/tint/bug/tint/2100.wgsl.expected.wgsl
new file mode 100644
index 0000000..360ee9c
--- /dev/null
+++ b/test/tint/bug/tint/2100.wgsl.expected.wgsl
@@ -0,0 +1,12 @@
+struct S {
+  matrix_view : mat4x4<f32>,
+  matrix_normal : mat3x3<f32>,
+}
+
+@group(0) @binding(0) var<uniform> buffer : S;
+
+@vertex
+fn main() -> @builtin(position) vec4f {
+  let x = buffer.matrix_view[0].z;
+  return vec4f(x, 0, 0, 1);
+}