writer/msl: Avoid generating unnecessary pointers

When moving private and workgroup variables into the entry point,
generate pointers to pass as arguments to sub-functions on demand,
instead of upfront. This removes a bunch of unnecessary dereferences
for accesses inside the entry point, and one function variable.

Change-Id: I7d1aabdf14eae33b569b3316dfc0f9fbd288131e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54300
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index c6a314b..f2ad35a 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -89,8 +89,7 @@
   // [[stage(compute)]]
   // fn main() {
   //   var<private> v : f32 = 2.0;
-  //   let v_ptr : ptr<private, f32> = &f32;
-  //   foo(v_ptr);
+  //   foo(&v);
   // }
   // ```
 
@@ -127,6 +126,7 @@
 
   for (auto* func_ast : functions_to_process) {
     auto* func_sem = ctx.src->Sem().Get(func_ast);
+    bool is_entry_point = func_ast->IsEntryPoint();
 
     // Map module-scope variables onto their function-scope replacement.
     std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
@@ -137,12 +137,12 @@
         continue;
       }
 
-      // This is the symbol for the pointer that replaces the module-scope var.
+      // This is the symbol for the variable that replaces the module-scope var.
       auto new_var_symbol = ctx.dst->Sym();
 
       auto* store_type = CreateASTTypeFor(&ctx, var->Type()->UnwrapRef());
 
-      if (func_ast->IsEntryPoint()) {
+      if (is_entry_point) {
         // For an entry point, redeclare the variable at function-scope.
         // Disable storage class validation on this variable.
         auto* disable_validation =
@@ -151,16 +151,10 @@
                 ast::DisabledValidation::kFunctionVarStorageClass);
         auto* constructor = ctx.Clone(var->Declaration()->constructor());
         auto* local_var =
-            ctx.dst->Var(ctx.dst->Sym(), store_type, var->StorageClass(),
+            ctx.dst->Var(new_var_symbol, store_type, var->StorageClass(),
                          constructor, ast::DecorationList{disable_validation});
         ctx.InsertBefore(func_ast->body()->statements(),
                          *func_ast->body()->begin(), ctx.dst->Decl(local_var));
-
-        // Now take the address of the variable.
-        auto* ptr = ctx.dst->Const(new_var_symbol, nullptr,
-                                   ctx.dst->AddressOf(local_var));
-        ctx.InsertBefore(func_ast->body()->statements(),
-                         *func_ast->body()->begin(), ctx.dst->Decl(ptr));
       } else {
         // For a regular function, redeclare the variable as a pointer function
         // parameter.
@@ -169,18 +163,22 @@
                        ctx.dst->Param(new_var_symbol, ptr_type));
       }
 
-      // Replace all uses of the module-scope variable with the pointer
-      // replacement (dereferenced).
+      // Replace all uses of the module-scope variable.
       for (auto* user : var->Users()) {
         if (user->Stmt()->Function() == func_ast) {
-          ctx.Replace(user->Declaration(), ctx.dst->Deref(new_var_symbol));
+          ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
+          if (!is_entry_point) {
+            // For non-entry points, dereference the pointer argument.
+            expr = ctx.dst->Deref(expr);
+          }
+          ctx.Replace(user->Declaration(), expr);
         }
       }
 
       var_to_symbol[var] = new_var_symbol;
     }
 
-    // Pass the pointers through to any functions that need them.
+    // Pass the variables as pointers to any functions that need them.
     for (auto* call : calls_to_replace[func_ast]) {
       auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
       auto* target_sem = ctx.src->Sem().Get(target);
@@ -189,8 +187,12 @@
       for (auto* target_var : target_sem->ReferencedModuleVariables()) {
         if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
             target_var->StorageClass() == ast::StorageClass::kWorkgroup) {
-          ctx.InsertBack(call->params(),
-                         ctx.dst->Expr(var_to_symbol[target_var]));
+          ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
+          if (is_entry_point) {
+            // For entry points, pass the address of the variable.
+            arg = ctx.dst->AddressOf(arg);
+          }
+          ctx.InsertBack(call->params(), arg);
         }
       }
     }
diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc
index 70f50d8..d070679 100644
--- a/src/transform/msl_test.cc
+++ b/src/transform/msl_test.cc
@@ -36,11 +36,9 @@
   auto* expect = R"(
 [[stage(compute)]]
 fn main() {
-  [[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_1 : f32;
-  let tint_symbol = &(tint_symbol_1);
-  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_3 : f32;
-  let tint_symbol_2 = &(tint_symbol_3);
-  *(tint_symbol) = *(tint_symbol_2);
+  [[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol : f32;
+  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32;
+  tint_symbol = tint_symbol_1;
 }
 )";
 
@@ -91,11 +89,9 @@
 
 [[stage(compute)]]
 fn main() {
-  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_5 : f32;
-  let tint_symbol_4 = &(tint_symbol_5);
-  [[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_7 : f32;
-  let tint_symbol_6 = &(tint_symbol_7);
-  foo(1.0, tint_symbol_4, tint_symbol_6);
+  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_4 : f32;
+  [[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_5 : f32;
+  foo(1.0, &(tint_symbol_4), &(tint_symbol_5));
 }
 )";
 
@@ -118,11 +114,9 @@
   auto* expect = R"(
 [[stage(compute)]]
 fn main() {
-  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32 = 1.0;
-  let tint_symbol = &(tint_symbol_1);
-  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_3 : f32 = f32();
-  let tint_symbol_2 = &(tint_symbol_3);
-  let x : f32 = (*(tint_symbol) + *(tint_symbol_2));
+  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol : f32 = 1.0;
+  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32 = f32();
+  let x : f32 = (tint_symbol + tint_symbol_1);
 }
 )";
 
@@ -148,12 +142,10 @@
   auto* expect = R"(
 [[stage(compute)]]
 fn main() {
-  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol_1 : f32;
-  let tint_symbol = &(tint_symbol_1);
-  [[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_3 : f32;
-  let tint_symbol_2 = &(tint_symbol_3);
-  let p_ptr : ptr<private, f32> = &(*(tint_symbol));
-  let w_ptr : ptr<workgroup, f32> = &(*(tint_symbol_2));
+  [[internal(disable_validation__function_var_storage_class)]] var<private> tint_symbol : f32;
+  [[internal(disable_validation__function_var_storage_class)]] var<workgroup> tint_symbol_1 : f32;
+  let p_ptr : ptr<private, f32> = &(tint_symbol);
+  let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
   let x : f32 = (*(p_ptr) + *(w_ptr));
   *(p_ptr) = x;
 }
diff --git a/src/writer/msl/generator_impl_variable_decl_statement_test.cc b/src/writer/msl/generator_impl_variable_decl_statement_test.cc
index 68e2983..d3eea3e 100644
--- a/src/writer/msl/generator_impl_variable_decl_statement_test.cc
+++ b/src/writer/msl/generator_impl_variable_decl_statement_test.cc
@@ -120,7 +120,7 @@
   gen.increment_indent();
 
   ASSERT_TRUE(gen.Generate()) << gen.error();
-  EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_2 = 0.0f;\n"));
+  EXPECT_THAT(gen.result(), HasSubstr("thread float tint_symbol_1 = 0.0f;\n"));
 }
 
 TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_Private) {
@@ -133,7 +133,7 @@
 
   ASSERT_TRUE(gen.Generate()) << gen.error();
   EXPECT_THAT(gen.result(),
-              HasSubstr("thread float tint_symbol_2 = initializer;\n"));
+              HasSubstr("thread float tint_symbol_1 = initializer;\n"));
 }
 
 TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) {
@@ -147,7 +147,7 @@
 
   ASSERT_TRUE(gen.Generate()) << gen.error();
   EXPECT_THAT(gen.result(),
-              HasSubstr("threadgroup float tint_symbol_2 = 0.0f;\n"));
+              HasSubstr("threadgroup float tint_symbol_1 = 0.0f;\n"));
 }
 
 TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) {
diff --git a/test/ptr_ref/load/global/i32.spvasm.expected.msl b/test/ptr_ref/load/global/i32.spvasm.expected.msl
index c59e28e..b12f809 100644
--- a/test/ptr_ref/load/global/i32.spvasm.expected.msl
+++ b/test/ptr_ref/load/global/i32.spvasm.expected.msl
@@ -2,9 +2,8 @@
 
 using namespace metal;
 kernel void tint_symbol() {
-  thread int tint_symbol_2 = 0;
-  thread int* const tint_symbol_1 = &(tint_symbol_2);
-  int const x_9 = *(tint_symbol_1);
+  thread int tint_symbol_1 = 0;
+  int const x_9 = tint_symbol_1;
   int const x_11 = (x_9 + 1);
   return;
 }
diff --git a/test/ptr_ref/load/global/i32.wgsl.expected.msl b/test/ptr_ref/load/global/i32.wgsl.expected.msl
index 9431488..d7d3c63 100644
--- a/test/ptr_ref/load/global/i32.wgsl.expected.msl
+++ b/test/ptr_ref/load/global/i32.wgsl.expected.msl
@@ -2,9 +2,8 @@
 
 using namespace metal;
 kernel void tint_symbol() {
-  thread int tint_symbol_2 = 0;
-  thread int* const tint_symbol_1 = &(tint_symbol_2);
-  int const i = *(tint_symbol_1);
+  thread int tint_symbol_1 = 0;
+  int const i = tint_symbol_1;
   int const use = (i + 1);
   return;
 }
diff --git a/test/ptr_ref/load/global/struct_field.spvasm.expected.msl b/test/ptr_ref/load/global/struct_field.spvasm.expected.msl
index e3352ce..e1c1520 100644
--- a/test/ptr_ref/load/global/struct_field.spvasm.expected.msl
+++ b/test/ptr_ref/load/global/struct_field.spvasm.expected.msl
@@ -6,10 +6,9 @@
 };
 
 kernel void tint_symbol() {
-  thread S tint_symbol_2 = {};
-  thread S* const tint_symbol_1 = &(tint_symbol_2);
+  thread S tint_symbol_1 = {};
   int i = 0;
-  int const x_15 = (*(tint_symbol_1)).i;
+  int const x_15 = tint_symbol_1.i;
   i = x_15;
   return;
 }
diff --git a/test/ptr_ref/load/global/struct_field.wgsl.expected.msl b/test/ptr_ref/load/global/struct_field.wgsl.expected.msl
index cca8c39..857a03e 100644
--- a/test/ptr_ref/load/global/struct_field.wgsl.expected.msl
+++ b/test/ptr_ref/load/global/struct_field.wgsl.expected.msl
@@ -6,9 +6,8 @@
 };
 
 kernel void tint_symbol() {
-  thread S tint_symbol_2 = {};
-  thread S* const tint_symbol_1 = &(tint_symbol_2);
-  int const i = (*(tint_symbol_1)).i;
+  thread S tint_symbol_1 = {};
+  int const i = tint_symbol_1.i;
   return;
 }
 
diff --git a/test/ptr_ref/load/local/ptr_private.wgsl.expected.msl b/test/ptr_ref/load/local/ptr_private.wgsl.expected.msl
index b41e142..c600e5b 100644
--- a/test/ptr_ref/load/local/ptr_private.wgsl.expected.msl
+++ b/test/ptr_ref/load/local/ptr_private.wgsl.expected.msl
@@ -2,9 +2,8 @@
 
 using namespace metal;
 kernel void tint_symbol() {
-  thread int tint_symbol_2 = 123;
-  thread int* const tint_symbol_1 = &(tint_symbol_2);
-  thread int* const p = &(*(tint_symbol_1));
+  thread int tint_symbol_1 = 123;
+  thread int* const p = &(tint_symbol_1);
   int const use = (*(p) + 1);
   return;
 }
diff --git a/test/ptr_ref/load/local/ptr_workgroup.wgsl.expected.msl b/test/ptr_ref/load/local/ptr_workgroup.wgsl.expected.msl
index 983ba40..1f4dfe8 100644
--- a/test/ptr_ref/load/local/ptr_workgroup.wgsl.expected.msl
+++ b/test/ptr_ref/load/local/ptr_workgroup.wgsl.expected.msl
@@ -2,10 +2,9 @@
 
 using namespace metal;
 kernel void tint_symbol() {
-  threadgroup int tint_symbol_2 = 0;
-  threadgroup int* const tint_symbol_1 = &(tint_symbol_2);
-  *(tint_symbol_1) = 123;
-  threadgroup int* const p = &(*(tint_symbol_1));
+  threadgroup int tint_symbol_1 = 0;
+  tint_symbol_1 = 123;
+  threadgroup int* const p = &(tint_symbol_1);
   int const use = (*(p) + 1);
   return;
 }
diff --git a/test/ptr_ref/store/global/i32.spvasm.expected.msl b/test/ptr_ref/store/global/i32.spvasm.expected.msl
index 92fbc80..7eaf2d6 100644
--- a/test/ptr_ref/store/global/i32.spvasm.expected.msl
+++ b/test/ptr_ref/store/global/i32.spvasm.expected.msl
@@ -2,10 +2,9 @@
 
 using namespace metal;
 kernel void tint_symbol() {
-  thread int tint_symbol_2 = 0;
-  thread int* const tint_symbol_1 = &(tint_symbol_2);
-  *(tint_symbol_1) = 123;
-  *(tint_symbol_1) = ((100 + 20) + 3);
+  thread int tint_symbol_1 = 0;
+  tint_symbol_1 = 123;
+  tint_symbol_1 = ((100 + 20) + 3);
   return;
 }
 
diff --git a/test/ptr_ref/store/global/i32.wgsl.expected.msl b/test/ptr_ref/store/global/i32.wgsl.expected.msl
index 92fbc80..7eaf2d6 100644
--- a/test/ptr_ref/store/global/i32.wgsl.expected.msl
+++ b/test/ptr_ref/store/global/i32.wgsl.expected.msl
@@ -2,10 +2,9 @@
 
 using namespace metal;
 kernel void tint_symbol() {
-  thread int tint_symbol_2 = 0;
-  thread int* const tint_symbol_1 = &(tint_symbol_2);
-  *(tint_symbol_1) = 123;
-  *(tint_symbol_1) = ((100 + 20) + 3);
+  thread int tint_symbol_1 = 0;
+  tint_symbol_1 = 123;
+  tint_symbol_1 = ((100 + 20) + 3);
   return;
 }
 
diff --git a/test/ptr_ref/store/global/struct_field.spvasm.expected.msl b/test/ptr_ref/store/global/struct_field.spvasm.expected.msl
index f66e919..243f1d6 100644
--- a/test/ptr_ref/store/global/struct_field.spvasm.expected.msl
+++ b/test/ptr_ref/store/global/struct_field.spvasm.expected.msl
@@ -6,9 +6,8 @@
 };
 
 kernel void tint_symbol() {
-  thread S tint_symbol_2 = {};
-  thread S* const tint_symbol_1 = &(tint_symbol_2);
-  (*(tint_symbol_1)).i = 5;
+  thread S tint_symbol_1 = {};
+  tint_symbol_1.i = 5;
   return;
 }
 
diff --git a/test/types/module_scope_var.wgsl.expected.msl b/test/types/module_scope_var.wgsl.expected.msl
index b857fe8..c3a49df 100644
--- a/test/types/module_scope_var.wgsl.expected.msl
+++ b/test/types/module_scope_var.wgsl.expected.msl
@@ -8,38 +8,28 @@
 };
 
 kernel void tint_symbol() {
-  thread bool tint_symbol_4 = false;
-  thread bool* const tint_symbol_3 = &(tint_symbol_4);
-  thread int tint_symbol_6 = 0;
-  thread int* const tint_symbol_5 = &(tint_symbol_6);
-  thread uint tint_symbol_8 = 0u;
-  thread uint* const tint_symbol_7 = &(tint_symbol_8);
-  thread float tint_symbol_10 = 0.0f;
-  thread float* const tint_symbol_9 = &(tint_symbol_10);
-  thread int2 tint_symbol_12 = 0;
-  thread int2* const tint_symbol_11 = &(tint_symbol_12);
-  thread uint3 tint_symbol_14 = 0u;
-  thread uint3* const tint_symbol_13 = &(tint_symbol_14);
-  thread float4 tint_symbol_16 = 0.0f;
-  thread float4* const tint_symbol_15 = &(tint_symbol_16);
-  thread float2x3 tint_symbol_18 = float2x3(0.0f);
-  thread float2x3* const tint_symbol_17 = &(tint_symbol_18);
-  thread tint_array_wrapper_0 tint_symbol_20 = {0.0f};
-  thread tint_array_wrapper_0* const tint_symbol_19 = &(tint_symbol_20);
-  thread S tint_symbol_22 = {};
-  thread S* const tint_symbol_21 = &(tint_symbol_22);
-  *(tint_symbol_3) = bool();
-  *(tint_symbol_5) = int();
-  *(tint_symbol_7) = uint();
-  *(tint_symbol_9) = float();
-  *(tint_symbol_11) = int2();
-  *(tint_symbol_13) = uint3();
-  *(tint_symbol_15) = float4();
-  *(tint_symbol_17) = float2x3();
+  thread bool tint_symbol_3 = false;
+  thread int tint_symbol_4 = 0;
+  thread uint tint_symbol_5 = 0u;
+  thread float tint_symbol_6 = 0.0f;
+  thread int2 tint_symbol_7 = 0;
+  thread uint3 tint_symbol_8 = 0u;
+  thread float4 tint_symbol_9 = 0.0f;
+  thread float2x3 tint_symbol_10 = float2x3(0.0f);
+  thread tint_array_wrapper_0 tint_symbol_11 = {0.0f};
+  thread S tint_symbol_12 = {};
+  tint_symbol_3 = bool();
+  tint_symbol_4 = int();
+  tint_symbol_5 = uint();
+  tint_symbol_6 = float();
+  tint_symbol_7 = int2();
+  tint_symbol_8 = uint3();
+  tint_symbol_9 = float4();
+  tint_symbol_10 = float2x3();
   tint_array_wrapper_0 const tint_symbol_1 = {};
-  *(tint_symbol_19) = tint_symbol_1;
+  tint_symbol_11 = tint_symbol_1;
   S const tint_symbol_2 = {};
-  *(tint_symbol_21) = tint_symbol_2;
+  tint_symbol_12 = tint_symbol_2;
   return;
 }
 
diff --git a/test/types/module_scope_var_initializers.wgsl.expected.msl b/test/types/module_scope_var_initializers.wgsl.expected.msl
index af7d470..6a49541 100644
--- a/test/types/module_scope_var_initializers.wgsl.expected.msl
+++ b/test/types/module_scope_var_initializers.wgsl.expected.msl
@@ -8,38 +8,28 @@
 };
 
 kernel void tint_symbol() {
-  thread bool tint_symbol_4 = bool();
-  thread bool* const tint_symbol_3 = &(tint_symbol_4);
-  thread int tint_symbol_6 = int();
-  thread int* const tint_symbol_5 = &(tint_symbol_6);
-  thread uint tint_symbol_8 = uint();
-  thread uint* const tint_symbol_7 = &(tint_symbol_8);
-  thread float tint_symbol_10 = float();
-  thread float* const tint_symbol_9 = &(tint_symbol_10);
-  thread int2 tint_symbol_12 = int2();
-  thread int2* const tint_symbol_11 = &(tint_symbol_12);
-  thread uint3 tint_symbol_14 = uint3();
-  thread uint3* const tint_symbol_13 = &(tint_symbol_14);
-  thread float4 tint_symbol_16 = float4();
-  thread float4* const tint_symbol_15 = &(tint_symbol_16);
-  thread float2x3 tint_symbol_18 = float2x3();
-  thread float2x3* const tint_symbol_17 = &(tint_symbol_18);
-  thread tint_array_wrapper_0 tint_symbol_20 = {};
-  thread tint_array_wrapper_0* const tint_symbol_19 = &(tint_symbol_20);
-  thread S tint_symbol_22 = {};
-  thread S* const tint_symbol_21 = &(tint_symbol_22);
-  *(tint_symbol_3) = bool();
-  *(tint_symbol_5) = int();
-  *(tint_symbol_7) = uint();
-  *(tint_symbol_9) = float();
-  *(tint_symbol_11) = int2();
-  *(tint_symbol_13) = uint3();
-  *(tint_symbol_15) = float4();
-  *(tint_symbol_17) = float2x3();
+  thread bool tint_symbol_3 = bool();
+  thread int tint_symbol_4 = int();
+  thread uint tint_symbol_5 = uint();
+  thread float tint_symbol_6 = float();
+  thread int2 tint_symbol_7 = int2();
+  thread uint3 tint_symbol_8 = uint3();
+  thread float4 tint_symbol_9 = float4();
+  thread float2x3 tint_symbol_10 = float2x3();
+  thread tint_array_wrapper_0 tint_symbol_11 = {};
+  thread S tint_symbol_12 = {};
+  tint_symbol_3 = bool();
+  tint_symbol_4 = int();
+  tint_symbol_5 = uint();
+  tint_symbol_6 = float();
+  tint_symbol_7 = int2();
+  tint_symbol_8 = uint3();
+  tint_symbol_9 = float4();
+  tint_symbol_10 = float2x3();
   tint_array_wrapper_0 const tint_symbol_1 = {};
-  *(tint_symbol_19) = tint_symbol_1;
+  tint_symbol_11 = tint_symbol_1;
   S const tint_symbol_2 = {};
-  *(tint_symbol_21) = tint_symbol_2;
+  tint_symbol_12 = tint_symbol_2;
   return;
 }
 
diff --git a/test/var/private.wgsl.expected.msl b/test/var/private.wgsl.expected.msl
index f17afeb..ed8cfa6 100644
--- a/test/var/private.wgsl.expected.msl
+++ b/test/var/private.wgsl.expected.msl
@@ -25,27 +25,23 @@
 }
 
 kernel void main1() {
-  thread int tint_symbol_7 = 0;
-  thread int* const tint_symbol_6 = &(tint_symbol_7);
-  *(tint_symbol_6) = 42;
-  uses_a(tint_symbol_6);
+  thread int tint_symbol_6 = 0;
+  tint_symbol_6 = 42;
+  uses_a(&(tint_symbol_6));
   return;
 }
 
 kernel void main2() {
-  thread int tint_symbol_9 = 0;
-  thread int* const tint_symbol_8 = &(tint_symbol_9);
-  *(tint_symbol_8) = 7;
-  uses_b(tint_symbol_8);
+  thread int tint_symbol_7 = 0;
+  tint_symbol_7 = 7;
+  uses_b(&(tint_symbol_7));
   return;
 }
 
 kernel void main3() {
-  thread int tint_symbol_11 = 0;
-  thread int* const tint_symbol_10 = &(tint_symbol_11);
-  thread int tint_symbol_13 = 0;
-  thread int* const tint_symbol_12 = &(tint_symbol_13);
-  outer(tint_symbol_10, tint_symbol_12);
+  thread int tint_symbol_8 = 0;
+  thread int tint_symbol_9 = 0;
+  outer(&(tint_symbol_8), &(tint_symbol_9));
   no_uses();
   return;
 }
diff --git a/test/var/workgroup.wgsl.expected.msl b/test/var/workgroup.wgsl.expected.msl
index 3c53df8..88cd0db 100644
--- a/test/var/workgroup.wgsl.expected.msl
+++ b/test/var/workgroup.wgsl.expected.msl
@@ -25,27 +25,23 @@
 }
 
 kernel void main1() {
-  threadgroup int tint_symbol_7 = 0;
-  threadgroup int* const tint_symbol_6 = &(tint_symbol_7);
-  *(tint_symbol_6) = 42;
-  uses_a(tint_symbol_6);
+  threadgroup int tint_symbol_6 = 0;
+  tint_symbol_6 = 42;
+  uses_a(&(tint_symbol_6));
   return;
 }
 
 kernel void main2() {
-  threadgroup int tint_symbol_9 = 0;
-  threadgroup int* const tint_symbol_8 = &(tint_symbol_9);
-  *(tint_symbol_8) = 7;
-  uses_b(tint_symbol_8);
+  threadgroup int tint_symbol_7 = 0;
+  tint_symbol_7 = 7;
+  uses_b(&(tint_symbol_7));
   return;
 }
 
 kernel void main3() {
-  threadgroup int tint_symbol_11 = 0;
-  threadgroup int* const tint_symbol_10 = &(tint_symbol_11);
-  threadgroup int tint_symbol_13 = 0;
-  threadgroup int* const tint_symbol_12 = &(tint_symbol_13);
-  outer(tint_symbol_10, tint_symbol_12);
+  threadgroup int tint_symbol_8 = 0;
+  threadgroup int tint_symbol_9 = 0;
+  outer(&(tint_symbol_8), &(tint_symbol_9));
   no_uses();
   return;
 }