HoistToDeclBefore: hoist to reference if expression is a reference type

Don't lose reference-ness of expression. This is necessary for
assignments via the hoisted variable, for example.

Bug: tint:1300
Change-Id: I8e633f20e50541bb70becc5069019f795ec11e01
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/80540
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/transform/utils/hoist_to_decl_before.cc b/src/transform/utils/hoist_to_decl_before.cc
index 2448506..aeff023 100644
--- a/src/transform/utils/hoist_to_decl_before.cc
+++ b/src/transform/utils/hoist_to_decl_before.cc
@@ -20,6 +20,8 @@
 #include "src/sem/block_statement.h"
 #include "src/sem/for_loop_statement.h"
 #include "src/sem/if_statement.h"
+#include "src/sem/reference_type.h"
+#include "src/sem/variable.h"
 #include "src/utils/reverse.h"
 
 namespace tint::transform {
@@ -265,10 +267,22 @@
                          const ast::Expression* expr,
                          bool as_const,
                          const char* decl_name = "") {
-    // Construct the let/var that holds the hoisted expr
     auto name = b.Symbols().New(decl_name);
-    auto* v = as_const ? b.Const(name, nullptr, ctx.Clone(expr))
-                       : b.Var(name, nullptr, ctx.Clone(expr));
+
+    auto* sem_expr = ctx.src->Sem().Get(expr);
+    bool is_ref =
+        sem_expr &&
+        !sem_expr->Is<sem::VariableUser>()  // Don't need to take a ref to a var
+        && sem_expr->Type()->Is<sem::Reference>();
+
+    auto* expr_clone = ctx.Clone(expr);
+    if (is_ref) {
+      expr_clone = b.AddressOf(expr_clone);
+    }
+
+    // Construct the let/var that holds the hoisted expr
+    auto* v = as_const ? b.Const(name, nullptr, expr_clone)
+                       : b.Var(name, nullptr, expr_clone);
     auto* decl = b.Decl(v);
 
     if (!InsertBefore(before_expr, decl)) {
@@ -276,7 +290,11 @@
     }
 
     // Replace the initializer expression with a reference to the let
-    ctx.Replace(expr, b.Expr(name));
+    const ast::Expression* new_expr = b.Expr(name);
+    if (is_ref) {
+      new_expr = b.Deref(new_expr);
+    }
+    ctx.Replace(expr, new_expr);
     return true;
   }
 
diff --git a/src/transform/utils/hoist_to_decl_before_test.cc b/src/transform/utils/hoist_to_decl_before_test.cc
index 5dddde4..8dd902d 100644
--- a/src/transform/utils/hoist_to_decl_before_test.cc
+++ b/src/transform/utils/hoist_to_decl_before_test.cc
@@ -217,5 +217,76 @@
   EXPECT_EQ(expect, str(cloned));
 }
 
+TEST_F(HoistToDeclBeforeTest, Array1D) {
+  // fn f() {
+  //     var a : array<i32, 10>;
+  //     var b = a[0];
+  // }
+  ProgramBuilder b;
+  auto* var1 = b.Decl(b.Var("a", b.ty.array<ProgramBuilder::i32, 10>()));
+  auto* expr = b.IndexAccessor("a", 0);
+  auto* var2 = b.Decl(b.Var("b", nullptr, expr));
+  b.Func("f", {}, b.ty.void_(), {var1, var2});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* sem_expr = ctx.src->Sem().Get(expr);
+  hoistToDeclBefore.Add(sem_expr, expr, true);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn f() {
+  var a : array<i32, 10>;
+  let tint_symbol = &(a[0]);
+  var b = *(tint_symbol);
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
+TEST_F(HoistToDeclBeforeTest, Array2D) {
+  // fn f() {
+  //     var a : array<array<i32, 10>, 10>;
+  //     var b = a[0][0];
+  // }
+  ProgramBuilder b;
+
+  auto* var1 =
+      b.Decl(b.Var("a", b.ty.array(b.ty.array<ProgramBuilder::i32, 10>(), 10)));
+  auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0), 0);
+  auto* var2 = b.Decl(b.Var("b", nullptr, expr));
+  b.Func("f", {}, b.ty.void_(), {var1, var2});
+
+  Program original(std::move(b));
+  ProgramBuilder cloned_b;
+  CloneContext ctx(&cloned_b, &original);
+  std::cout << str(original) << std::endl;
+
+  HoistToDeclBefore hoistToDeclBefore(ctx);
+  auto* sem_expr = ctx.src->Sem().Get(expr);
+  hoistToDeclBefore.Add(sem_expr, expr, true);
+  hoistToDeclBefore.Apply();
+
+  ctx.Clone();
+  Program cloned(std::move(cloned_b));
+
+  auto* expect = R"(
+fn f() {
+  var a : array<array<i32, 10>, 10>;
+  let tint_symbol = &(a[0][0]);
+  var b = *(tint_symbol);
+}
+)";
+
+  EXPECT_EQ(expect, str(cloned));
+}
+
 }  // namespace
 }  // namespace tint::transform