transform::VarForDynamicIndex: Operate on matrices

Much like arrays, the SPIR-V writer cannot cope with dynamic indexing of matrices.

Fixed: tint:825
Change-Id: Ia111f15e0cf6fbd441861a4b3455a33b82b692ab
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51781
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/transform/var_for_dynamic_index.cc b/src/transform/var_for_dynamic_index.cc
index e9af1e7..ca24645 100644
--- a/src/transform/var_for_dynamic_index.cc
+++ b/src/transform/var_for_dynamic_index.cc
@@ -37,7 +37,7 @@
     if (auto* access_expr = node->As<ast::ArrayAccessorExpression>()) {
       // Found an array accessor expression
       auto* index_expr = access_expr->idx_expr();
-      auto* array_expr = access_expr->array();
+      auto* indexed_expr = access_expr->array();
 
       if (index_expr->Is<ast::ScalarConstructorExpression>()) {
         // Index expression is a literal value. As this isn't a dynamic index,
@@ -45,20 +45,20 @@
         continue;
       }
 
-      auto* array = ctx.src->Sem().Get(array_expr);
-      if (!array->Type()->Is<sem::Array>()) {
-        // This transform currently only cares about arrays.
+      auto* indexed = ctx.src->Sem().Get(indexed_expr);
+      if (!indexed->Type()->IsAnyOf<sem::Array, sem::Matrix>()) {
+        // This transform currently only cares about array and matrices.
         continue;
       }
 
-      auto* stmt = array->Stmt();   // Statement that owns the expression
-      auto* block = stmt->Block();  // Block that owns the statement
+      auto* stmt = indexed->Stmt();  // Statement that owns the expression
+      auto* block = stmt->Block();   // Block that owns the statement
 
       // Construct a `var` declaration to hold the value in memory.
-      auto* ty = CreateASTTypeFor(&ctx, array->Type());
-      auto var_name = ctx.dst->Symbols().New("var_for_array");
+      auto* ty = CreateASTTypeFor(&ctx, indexed->Type());
+      auto var_name = ctx.dst->Symbols().New("var_for_index");
       auto* var_decl = ctx.dst->Decl(ctx.dst->Var(
-          var_name, ty, ast::StorageClass::kNone, ctx.Clone(array_expr)));
+          var_name, ty, ast::StorageClass::kNone, ctx.Clone(indexed_expr)));
 
       // Insert the `var` declaration before the statement that performs the
       // indexing. Note that for indexing chains, AST node ordering guarantees
@@ -67,7 +67,7 @@
                        var_decl);
 
       // Replace the original index expression with the new `var`.
-      ctx.Replace(array_expr, ctx.dst->Expr(var_name));
+      ctx.Replace(indexed_expr, ctx.dst->Expr(var_name));
     }
   }
 
diff --git a/src/transform/var_for_dynamic_index.h b/src/transform/var_for_dynamic_index.h
index 7d928c6..86fa30a 100644
--- a/src/transform/var_for_dynamic_index.h
+++ b/src/transform/var_for_dynamic_index.h
@@ -23,10 +23,10 @@
 namespace tint {
 namespace transform {
 
-/// A transform that extracts array values that are dynamically indexed to a
-/// temporary `var` local before performing the index. This transform is used by
-/// the SPIR-V writer for dynamically indexing arrays, as there is no SPIR-V
-/// instruction that can dynamically index a non-pointer composite.
+/// A transform that extracts array and matrix values that are dynamically
+/// indexed to a temporary `var` local before performing the index. This
+/// transform is used by the SPIR-V writer as there is no SPIR-V instruction
+/// that can dynamically index a non-pointer composite.
 class VarForDynamicIndex : public Transform {
  public:
   /// Constructor
diff --git a/src/transform/var_for_dynamic_index_test.cc b/src/transform/var_for_dynamic_index_test.cc
index 4626dea..29dd84e 100644
--- a/src/transform/var_for_dynamic_index_test.cc
+++ b/src/transform/var_for_dynamic_index_test.cc
@@ -44,8 +44,8 @@
 fn f() {
   var i : i32;
   let p : array<i32, 4> = array<i32, 4>(1, 2, 3, 4);
-  var var_for_array : array<i32, 4> = p;
-  let x : i32 = var_for_array[i];
+  var var_for_index : array<i32, 4> = p;
+  let x : i32 = var_for_index[i];
 }
 )";
 
@@ -67,7 +67,7 @@
   // TODO(bclayton): Optimize this case:
   // This output is not as efficient as it could be.
   // We only actually need to hoist the inner-most array to a `var`
-  // (`var_for_array`), as later indexing operations will be working with
+  // (`var_for_index`), as later indexing operations will be working with
   // references, not values.
 
   auto* expect = R"(
@@ -75,9 +75,9 @@
   var i : i32;
   var j : i32;
   let p : array<array<i32, 2>, 2> = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
-  var var_for_array : array<array<i32, 2>, 2> = p;
-  var var_for_array_1 : array<i32, 2> = var_for_array[i];
-  let x : i32 = var_for_array_1[j];
+  var var_for_index : array<array<i32, 2>, 2> = p;
+  var var_for_index_1 : array<i32, 2> = var_for_index[i];
+  let x : i32 = var_for_index_1[j];
 }
 )";
 
diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc
index 7e1cb1d..f1afe46 100644
--- a/src/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/writer/spirv/builder_accessor_expression_test.cc
@@ -22,9 +22,9 @@
 
 using BuilderTest = TestHelper;
 
-TEST_F(BuilderTest, ArrayAccessor) {
-  // vec3<f32> ary;
-  // ary[1]  -> ptr<f32>
+TEST_F(BuilderTest, ArrayAccessor_VectorRef_Literal) {
+  // var ary : vec3<f32>;
+  // ary[1]  -> ref<f32>
 
   auto* var = Var("ary", ty.vec3<f32>());
 
@@ -57,10 +57,10 @@
 )");
 }
 
-TEST_F(BuilderTest, Accessor_Array_LoadIndex) {
-  // ary : vec3<f32>;
-  // idx : i32;
-  // ary[idx]  -> ptr<f32>
+TEST_F(BuilderTest, ArrayAccessor_VectorRef_Dynamic) {
+  // var ary : vec3<f32>;
+  // var idx : i32;
+  // ary[idx]  -> ref<f32>
 
   auto* var = Var("ary", ty.vec3<f32>());
   auto* idx = Var("idx", ty.i32());
@@ -98,9 +98,9 @@
 )");
 }
 
-TEST_F(BuilderTest, ArrayAccessor_Dynamic) {
-  // vec3<f32> ary;
-  // ary[1 + 2]  -> ptr<f32>
+TEST_F(BuilderTest, ArrayAccessor_VectorRef_Dynamic2) {
+  // var ary : vec3<f32>;
+  // ary[1 + 2]  -> ref<f32>
 
   auto* var = Var("ary", ty.vec3<f32>());
 
@@ -134,10 +134,10 @@
 )");
 }
 
-TEST_F(BuilderTest, ArrayAccessor_MultiLevel) {
+TEST_F(BuilderTest, ArrayAccessor_ArrayRef_MultiLevel) {
   auto* ary4 = ty.array(ty.vec3<f32>(), 4);
 
-  // ary = array<vec3<f32>, 4>
+  // var ary : array<vec3<f32>, 4>
   // ary[3][2];
 
   auto* var = Var("ary", ary4);
@@ -172,7 +172,7 @@
 )");
 }
 
-TEST_F(BuilderTest, Accessor_ArrayWithSwizzle) {
+TEST_F(BuilderTest, ArrayAccessor_ArrayRef_ArrayWithSwizzle) {
   auto* ary4 = ty.array(ty.vec3<f32>(), 4);
 
   // var a : array<vec3<f32>, 4>;
@@ -680,7 +680,7 @@
 )");
 }
 
-TEST_F(BuilderTest, Accessor_Mixed_ArrayAndMember) {
+TEST_F(BuilderTest, ArrayAccessor_Mixed_ArrayAndMember) {
   // type C = struct {
   //   baz : vec3<f32>
   // }
@@ -747,7 +747,7 @@
 )");
 }
 
-TEST_F(BuilderTest, Accessor_Array_Of_Vec) {
+TEST_F(BuilderTest, ArrayAccessor_Of_Vec) {
   // let pos : array<vec2<f32>, 3> = array<vec2<f32>, 3>(
   //   vec2<f32>(0.0, 0.5),
   //   vec2<f32>(-0.5, -0.5),
@@ -790,7 +790,7 @@
   Validate(b);
 }
 
-TEST_F(BuilderTest, Accessor_Array_Of_Array_Of_f32) {
+TEST_F(BuilderTest, ArrayAccessor_Of_Array_Of_f32) {
   // let pos : array<array<f32, 2>, 3> = array<vec2<f32, 2>, 3>(
   //   array<f32, 2>(0.0, 0.5),
   //   array<f32, 2>(-0.5, -0.5),
@@ -835,7 +835,7 @@
   Validate(b);
 }
 
-TEST_F(BuilderTest, Accessor_Const_Vec) {
+TEST_F(BuilderTest, ArrayAccessor_Vec_Literal) {
   // let pos : vec2<f32> = vec2<f32>(0.0, 0.5);
   // pos[1]
 
@@ -864,7 +864,7 @@
 )");
 }
 
-TEST_F(BuilderTest, Accessor_Const_Vec_Dynamic) {
+TEST_F(BuilderTest, ArrayAccessor_Vec_Dynamic) {
   // let pos : vec2<f32> = vec2<f32>(0.0, 0.5);
   // idx : i32
   // pos[idx]
@@ -900,7 +900,7 @@
 )");
 }
 
-TEST_F(BuilderTest, Accessor_Array_NonPointer) {
+TEST_F(BuilderTest, ArrayAccessor_Array_Literal) {
   // let a : array<f32, 3>;
   // a[2]
 
@@ -934,7 +934,7 @@
   Validate(b);
 }
 
-TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) {
+TEST_F(BuilderTest, ArrayAccessor_Array_Dynamic) {
   // let a : array<f32, 3>;
   // idx : i32
   // a[idx]
@@ -982,6 +982,58 @@
   Validate(b);
 }
 
+TEST_F(BuilderTest, ArrayAccessor_Matrix_Dynamic) {
+  // let a : mat2x2<f32>(vec2<f32>(1., 2.), vec2<f32>(3., 4.));
+  // idx : i32
+  // a[idx]
+
+  auto* var =
+      Const("a", ty.mat2x2<f32>(),
+            Construct(ty.mat2x2<f32>(), Construct(ty.vec2<f32>(), 1.f, 2.f),
+                      Construct(ty.vec2<f32>(), 3.f, 4.f)));
+
+  auto* idx = Var("idx", ty.i32());
+  auto* expr = IndexAccessor("a", idx);
+
+  WrapInFunction(var, idx, expr);
+
+  spirv::Builder& b = SanitizeAndBuild();
+
+  ASSERT_TRUE(b.Build());
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%7 = OpTypeFloat 32
+%6 = OpTypeVector %7 2
+%5 = OpTypeMatrix %6 2
+%8 = OpConstant %7 1
+%9 = OpConstant %7 2
+%10 = OpConstantComposite %6 %8 %9
+%11 = OpConstant %7 3
+%12 = OpConstant %7 4
+%13 = OpConstantComposite %6 %11 %12
+%14 = OpConstantComposite %5 %10 %13
+%17 = OpTypeInt 32 1
+%16 = OpTypePointer Function %17
+%18 = OpConstantNull %17
+%20 = OpTypePointer Function %5
+%21 = OpConstantNull %5
+%23 = OpTypePointer Function %6
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+            R"(%15 = OpVariable %16 Function %18
+%19 = OpVariable %20 Function %21
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpStore %19 %14
+%22 = OpLoad %17 %15
+%24 = OpAccessChain %23 %19 %22
+%25 = OpLoad %6 %24
+)");
+
+  Validate(b);
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/test/bug/tint/824.wgsl.expected.spvasm b/test/bug/tint/824.wgsl.expected.spvasm
index bbdd3fd..394df83 100644
--- a/test/bug/tint/824.wgsl.expected.spvasm
+++ b/test/bug/tint/824.wgsl.expected.spvasm
@@ -17,9 +17,9 @@
                OpName %tint_symbol_5 "tint_symbol_5"
                OpName %tint_symbol_2 "tint_symbol_2"
                OpName %main "main"
-               OpName %var_for_array "var_for_array"
+               OpName %var_for_index "var_for_index"
                OpName %output "output"
-               OpName %var_for_array_1 "var_for_array_1"
+               OpName %var_for_index_1 "var_for_index_1"
                OpDecorate %tint_pointsize BuiltIn PointSize
                OpDecorate %tint_symbol BuiltIn VertexIndex
                OpDecorate %tint_symbol_1 BuiltIn InstanceIndex
@@ -88,21 +88,21 @@
                OpFunctionEnd
        %main = OpFunction %void None %22
          %24 = OpLabel
-%var_for_array = OpVariable %_ptr_Function__arr_v2float_uint_4 Function %40
+%var_for_index = OpVariable %_ptr_Function__arr_v2float_uint_4 Function %40
      %output = OpVariable %_ptr_Function_Output Function %48
-%var_for_array_1 = OpVariable %_ptr_Function__arr_v4float_uint_4 Function %62
+%var_for_index_1 = OpVariable %_ptr_Function__arr_v4float_uint_4 Function %62
                OpStore %tint_pointsize %float_1
-               OpStore %var_for_array %37
+               OpStore %var_for_index %37
          %41 = OpLoad %uint %tint_symbol_1
-         %44 = OpAccessChain %_ptr_Function_float %var_for_array %41 %uint_0
+         %44 = OpAccessChain %_ptr_Function_float %var_for_index %41 %uint_0
          %45 = OpLoad %float %44
          %50 = OpAccessChain %_ptr_Function_v4float %output %uint_0
          %52 = OpCompositeConstruct %v4float %float_0_5 %float_0_5 %45 %float_1
                OpStore %50 %52
-               OpStore %var_for_array_1 %59
+               OpStore %var_for_index_1 %59
          %64 = OpAccessChain %_ptr_Function_v4float %output %uint_1
          %65 = OpLoad %uint %tint_symbol_1
-         %66 = OpAccessChain %_ptr_Function_v4float %var_for_array_1 %65
+         %66 = OpAccessChain %_ptr_Function_v4float %var_for_index_1 %65
          %67 = OpLoad %v4float %66
                OpStore %64 %67
          %69 = OpLoad %Output %output
diff --git a/test/bug/tint/825.wgsl b/test/bug/tint/825.wgsl
new file mode 100644
index 0000000..92a0b7e
--- /dev/null
+++ b/test/bug/tint/825.wgsl
@@ -0,0 +1,6 @@
+fn f() {
+  var i : i32;
+  var j : i32;
+  let m : mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+  let f : f32 = m[i][j];
+}
diff --git a/test/bug/tint/825.wgsl.expected.hlsl b/test/bug/tint/825.wgsl.expected.hlsl
new file mode 100644
index 0000000..0acca6f
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.hlsl
@@ -0,0 +1,12 @@
+void f() {
+  int i = 0;
+  int j = 0;
+  const float2x2 m = float2x2(float2(1.0f, 2.0f), float2(3.0f, 4.0f));
+  const float f = m[i][j];
+}
+
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
diff --git a/test/bug/tint/825.wgsl.expected.msl b/test/bug/tint/825.wgsl.expected.msl
new file mode 100644
index 0000000..5a8c360
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.msl
@@ -0,0 +1,10 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void f() {
+  int i = 0;
+  int j = 0;
+  float2x2 const m = float2x2(float2(1.0f, 2.0f), float2(3.0f, 4.0f));
+  float const f = m[i][j];
+}
+
diff --git a/test/bug/tint/825.wgsl.expected.spvasm b/test/bug/tint/825.wgsl.expected.spvasm
new file mode 100644
index 0000000..d7996bf
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.spvasm
@@ -0,0 +1,48 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 30
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+               OpExecutionMode %unused_entry_point LocalSize 1 1 1
+               OpName %unused_entry_point "unused_entry_point"
+               OpName %f "f"
+               OpName %i "i"
+               OpName %j "j"
+               OpName %var_for_index "var_for_index"
+       %void = OpTypeVoid
+          %1 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+         %10 = OpConstantNull %int
+      %float = OpTypeFloat 32
+    %v2float = OpTypeVector %float 2
+%mat2v2float = OpTypeMatrix %v2float 2
+    %float_1 = OpConstant %float 1
+    %float_2 = OpConstant %float 2
+         %17 = OpConstantComposite %v2float %float_1 %float_2
+    %float_3 = OpConstant %float 3
+    %float_4 = OpConstant %float 4
+         %20 = OpConstantComposite %v2float %float_3 %float_4
+         %21 = OpConstantComposite %mat2v2float %17 %20
+%_ptr_Function_mat2v2float = OpTypePointer Function %mat2v2float
+         %24 = OpConstantNull %mat2v2float
+%_ptr_Function_float = OpTypePointer Function %float
+%unused_entry_point = OpFunction %void None %1
+          %4 = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %f = OpFunction %void None %1
+          %6 = OpLabel
+          %i = OpVariable %_ptr_Function_int Function %10
+          %j = OpVariable %_ptr_Function_int Function %10
+%var_for_index = OpVariable %_ptr_Function_mat2v2float Function %24
+               OpStore %var_for_index %21
+         %25 = OpLoad %int %i
+         %26 = OpLoad %int %j
+         %28 = OpAccessChain %_ptr_Function_float %var_for_index %25 %26
+         %29 = OpLoad %float %28
+               OpReturn
+               OpFunctionEnd
diff --git a/test/bug/tint/825.wgsl.expected.wgsl b/test/bug/tint/825.wgsl.expected.wgsl
new file mode 100644
index 0000000..92a0b7e
--- /dev/null
+++ b/test/bug/tint/825.wgsl.expected.wgsl
@@ -0,0 +1,6 @@
+fn f() {
+  var i : i32;
+  var j : i32;
+  let m : mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
+  let f : f32 = m[i][j];
+}
diff --git a/test/samples/triangle.wgsl.expected.spvasm b/test/samples/triangle.wgsl.expected.spvasm
index c6f874d..5c9ada3 100644
--- a/test/samples/triangle.wgsl.expected.spvasm
+++ b/test/samples/triangle.wgsl.expected.spvasm
@@ -16,7 +16,7 @@
                OpName %tint_symbol_3 "tint_symbol_3"
                OpName %tint_symbol_1 "tint_symbol_1"
                OpName %vtx_main "vtx_main"
-               OpName %var_for_array "var_for_array"
+               OpName %var_for_index "var_for_index"
                OpName %tint_symbol_6 "tint_symbol_6"
                OpName %tint_symbol_4 "tint_symbol_4"
                OpName %frag_main "frag_main"
@@ -64,11 +64,11 @@
                OpFunctionEnd
    %vtx_main = OpFunction %void None %29
          %31 = OpLabel
-%var_for_array = OpVariable %_ptr_Function__arr_v2float_uint_3 Function %35
+%var_for_index = OpVariable %_ptr_Function__arr_v2float_uint_3 Function %35
                OpStore %tint_pointsize %float_1
-               OpStore %var_for_array %pos
+               OpStore %var_for_index %pos
          %37 = OpLoad %int %tint_symbol
-         %39 = OpAccessChain %_ptr_Function_v2float %var_for_array %37
+         %39 = OpAccessChain %_ptr_Function_v2float %var_for_index %37
          %40 = OpLoad %v2float %39
          %41 = OpCompositeExtract %float %40 0
          %42 = OpCompositeExtract %float %40 1