writer/spirv: Fix dynamic array accessors

If the initial array accessor in a chain uses a non-literal index, use
the path that copies the source to a function variable, and then
perform a load from the OpAccessChain result if necessary.

Fixed: tint:426
Change-Id: Ie2f3f388170c02c1d6b73355f0b3bc49c3d3a4e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49800
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 8a74ad5..7148a78 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -822,11 +822,8 @@
   auto* type = TypeOf(expr->idx_expr());
   idx_id = GenerateLoadIfNeeded(type, idx_id);
 
-  // If the source is a pointer we access chain into it. We also access chain
-  // into an array of non-scalar types.
-  if (info->source_type->Is<sem::Pointer>() ||
-      (info->source_type->Is<sem::ArrayType>() &&
-       !info->source_type->As<sem::ArrayType>()->type()->is_scalar())) {
+  // If the source is a pointer, we access chain into it.
+  if (info->source_type->Is<sem::Pointer>()) {
     info->access_chain_indices.push_back(idx_id);
     info->source_type = TypeOf(expr);
     return true;
@@ -1063,17 +1060,21 @@
   }
   info.source_type = TypeOf(source);
 
-  // If our initial access is into an array of non-scalar types, and that array
-  // is not a pointer, then we need to load that array into a variable in order
-  // to access chain into the array.
+  // If our initial access is into a non-pointer array, and either has a
+  // non-scalar element type or the accessor uses a non-literal index, then we
+  // need to load that array into a variable in order to access chain into it.
+  // TODO(jrprice): The non-scalar part shouldn't be necessary, but is tied to
+  // how the Resolver currently determines the type of these expression. This
+  // should be fixed when proper support for ptr/ref types is implemented.
   if (auto* array = accessors[0]->As<ast::ArrayAccessorExpression>()) {
-    auto* ary_res_type = TypeOf(array->array());
-
-    if (!ary_res_type->Is<sem::Pointer>() &&
-        (ary_res_type->Is<sem::ArrayType>() &&
-         !ary_res_type->As<sem::ArrayType>()->type()->is_scalar())) {
-      sem::Pointer ptr(ary_res_type, ast::StorageClass::kFunction);
-      auto result_type_id = GenerateTypeIfNeeded(&ptr);
+    auto* ary_res_type = TypeOf(array->array())->As<sem::ArrayType>();
+    if (ary_res_type &&
+        (!ary_res_type->type()->is_scalar() ||
+         !array->idx_expr()->Is<ast::ScalarConstructorExpression>())) {
+      // Wrap the source type in a pointer to function storage.
+      auto ptr =
+          builder_.ty.pointer(ary_res_type, ast::StorageClass::kFunction);
+      auto result_type_id = GenerateTypeIfNeeded(ptr);
       if (result_type_id == 0) {
         return 0;
       }
@@ -1094,6 +1095,7 @@
       }
 
       info.source_id = ary_result.to_i();
+      info.source_type = ptr;
     }
   }
 
@@ -1115,7 +1117,17 @@
   }
 
   if (!info.access_chain_indices.empty()) {
-    auto result_type_id = GenerateTypeIfNeeded(TypeOf(expr));
+    bool needs_load = false;
+    auto* ptr = TypeOf(expr);
+    if (!ptr->Is<sem::Pointer>()) {
+      // We are performing an access chain but the final result is not a
+      // pointer, so we need to perform a load to get it. This happens when we
+      // have to copy the source expression into a function variable.
+      ptr = builder_.ty.pointer(ptr, ast::StorageClass::kFunction);
+      needs_load = true;
+    }
+
+    auto result_type_id = GenerateTypeIfNeeded(ptr);
     if (result_type_id == 0) {
       return 0;
     }
@@ -1133,6 +1145,11 @@
       return false;
     }
     info.source_id = result_id;
+
+    // Load from the access chain result if required.
+    if (needs_load) {
+      info.source_id = GenerateLoadIfNeeded(ptr, result_id);
+    }
   }
 
   return info.source_id;
diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc
index d0c409b..80819a0 100644
--- a/src/writer/spirv/builder_accessor_expression_test.cc
+++ b/src/writer/spirv/builder_accessor_expression_test.cc
@@ -895,13 +895,54 @@
 )");
 }
 
-TEST_F(BuilderTest, DISABLED_Accessor_Array_NonPointer_Dynamic) {
+TEST_F(BuilderTest, Accessor_Array_NonPointer_Dynamic) {
   // let a : array<f32, 3>;
   // idx : i32
   // a[idx]
-  //
-  // This needs to copy the array to an OpVariable in the Function storage class
-  // and then access chain into it and load the result.
+
+  auto* var = GlobalConst("a", ty.array<f32, 3>(),
+                          Construct(ty.array<f32, 3>(), 0.0f, 0.5f, 1.0f));
+
+  auto* idx = Var("idx", ty.i32(), ast::StorageClass::kFunction);
+  auto* expr = IndexAccessor("a", idx);
+
+  ast::StatementList body;
+  body.push_back(WrapInStatement(idx));
+  body.push_back(WrapInStatement(expr));
+  WrapInFunction(body);
+
+  spirv::Builder& b = Build();
+
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
+  ASSERT_TRUE(b.GenerateFunctionVariable(idx)) << b.error();
+  EXPECT_EQ(b.GenerateAccessorExpression(expr), 19u) << b.error();
+
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
+%3 = OpTypeInt 32 0
+%4 = OpConstant %3 3
+%1 = OpTypeArray %2 %4
+%5 = OpConstant %2 0
+%6 = OpConstant %2 0.5
+%7 = OpConstant %2 1
+%8 = OpConstantComposite %1 %5 %6 %7
+%11 = OpTypeInt 32 1
+%10 = OpTypePointer Function %11
+%12 = OpConstantNull %11
+%13 = OpTypePointer Function %1
+%15 = OpConstantNull %1
+%17 = OpTypePointer Function %2
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+            R"(%9 = OpVariable %10 Function %12
+%14 = OpVariable %13 Function %15
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpStore %14 %8
+%16 = OpLoad %11 %9
+%18 = OpAccessChain %17 %14 %16
+%19 = OpLoad %2 %18
+)");
 }
 
 }  // namespace