[msl] Fix PackedVec3 for atomic builtins

Only unwrap pointers before load instructions, instead of before
calling `UpdateUsage()`, so that the `packed == unpacked` check at the
start of `UpdateUsage()` can correctly determine when the target type
is not in fact packed.

Fixed: 366314931
Change-Id: Ie8e3c0bf87a0afd0d13199227fc9687e57fe0809
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/206395
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/lang/msl/writer/raise/packed_vec3.cc b/src/tint/lang/msl/writer/raise/packed_vec3.cc
index 058bb03..aa059e6 100644
--- a/src/tint/lang/msl/writer/raise/packed_vec3.cc
+++ b/src/tint/lang/msl/writer/raise/packed_vec3.cc
@@ -291,12 +291,12 @@
                 auto* packed_result_type = RewriteType(unpacked_result_type);
                 let->Result(0)->SetType(packed_result_type);
                 let->Result(0)->ForEachUseSorted([&](core::ir::Usage let_use) {  //
-                    UpdateUsage(let_use, unpacked_result_type->UnwrapPtr(), packed_result_type);
+                    UpdateUsage(let_use, unpacked_result_type, packed_result_type);
                 });
             },
             [&](core::ir::Load* load) {
                 b.InsertAfter(load, [&] {
-                    auto* result = LoadPackedToUnpacked(unpacked_type, load->From());
+                    auto* result = LoadPackedToUnpacked(unpacked_type->UnwrapPtr(), load->From());
                     load->Result(0)->ReplaceAllUsesWith(result);
                 });
                 load->Destroy();
@@ -327,7 +327,7 @@
         // Rebuild the indices of the access instruction.
         // Walk through the intermediate types that the access chain will be traversing, and
         // check for packed vectors that would be wrapped in structures.
-        auto* obj_type = unpacked_type;
+        auto* obj_type = unpacked_type->UnwrapPtr();
         Vector<core::ir::Value*, 4> operands;
         operands.Push(access->Object());
         for (auto* idx : access->Indices()) {
@@ -354,7 +354,7 @@
         access->SetOperands(std::move(operands));
         access->Result(0)->SetType(packed_result_type);
         access->Result(0)->ForEachUseSorted([&](core::ir::Usage access_use) {  //
-            UpdateUsage(access_use, unpacked_result_type->UnwrapPtr(), packed_result_type);
+            UpdateUsage(access_use, unpacked_result_type, packed_result_type);
         });
     }
 
diff --git a/src/tint/lang/msl/writer/raise/packed_vec3_test.cc b/src/tint/lang/msl/writer/raise/packed_vec3_test.cc
index e7704aa..081e5d0 100644
--- a/src/tint/lang/msl/writer/raise/packed_vec3_test.cc
+++ b/src/tint/lang/msl/writer/raise/packed_vec3_test.cc
@@ -3692,5 +3692,137 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(MslWriter_PackedVec3Test, AtomicOnPackedStructMember) {
+    auto* s = ty.Struct(mod.symbols.New("S"), {
+                                                  {mod.symbols.Register("vec"), ty.vec3<u32>()},
+                                                  {mod.symbols.Register("u"), ty.atomic<u32>()},
+                                              });
+
+    auto* var = b.Var("v", ty.ptr<workgroup>(s));
+    mod.root_block->Append(var);
+
+    auto* func = b.Function("foo", ty.u32());
+    b.Append(func->Block(), [&] {  //
+        auto* p = b.Access<ptr<workgroup, atomic<u32>>>(var, 1_u);
+        auto* result = b.Call<u32>(core::BuiltinFn::kAtomicLoad, p);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+S = struct @align(16) {
+  vec:vec3<u32> @offset(0)
+  u:atomic<u32> @offset(12)
+}
+
+$B1: {  # root
+  %v:ptr<workgroup, S, read_write> = var
+}
+
+%foo = func():u32 {
+  $B2: {
+    %3:ptr<workgroup, atomic<u32>, read_write> = access %v, 1u
+    %4:u32 = atomicLoad %3
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+S = struct @align(16) {
+  vec:vec3<u32> @offset(0)
+  u:atomic<u32> @offset(12)
+}
+
+S_packed_vec3 = struct @align(16) {
+  vec:__packed_vec3<u32> @offset(0)
+  u:atomic<u32> @offset(12)
+}
+
+$B1: {  # root
+  %v:ptr<workgroup, S_packed_vec3, read_write> = var
+}
+
+%foo = func():u32 {
+  $B2: {
+    %3:ptr<workgroup, atomic<u32>, read_write> = access %v, 1u
+    %4:u32 = atomicLoad %3
+    ret %4
+  }
+}
+)";
+
+    Run(PackedVec3);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(MslWriter_PackedVec3Test, AtomicOnPackedStructMember_ViaLet) {
+    auto* s = ty.Struct(mod.symbols.New("S"), {
+                                                  {mod.symbols.Register("vec"), ty.vec3<u32>()},
+                                                  {mod.symbols.Register("u"), ty.atomic<u32>()},
+                                              });
+
+    auto* var = b.Var("v", ty.ptr<workgroup>(s));
+    mod.root_block->Append(var);
+
+    auto* func = b.Function("foo", ty.u32());
+    b.Append(func->Block(), [&] {  //
+        auto* p = b.Let("p", b.Access<ptr<workgroup, atomic<u32>>>(var, 1_u));
+        auto* result = b.Call<u32>(core::BuiltinFn::kAtomicLoad, p);
+        b.Return(func, result);
+    });
+
+    auto* src = R"(
+S = struct @align(16) {
+  vec:vec3<u32> @offset(0)
+  u:atomic<u32> @offset(12)
+}
+
+$B1: {  # root
+  %v:ptr<workgroup, S, read_write> = var
+}
+
+%foo = func():u32 {
+  $B2: {
+    %3:ptr<workgroup, atomic<u32>, read_write> = access %v, 1u
+    %p:ptr<workgroup, atomic<u32>, read_write> = let %3
+    %5:u32 = atomicLoad %p
+    ret %5
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+S = struct @align(16) {
+  vec:vec3<u32> @offset(0)
+  u:atomic<u32> @offset(12)
+}
+
+S_packed_vec3 = struct @align(16) {
+  vec:__packed_vec3<u32> @offset(0)
+  u:atomic<u32> @offset(12)
+}
+
+$B1: {  # root
+  %v:ptr<workgroup, S_packed_vec3, read_write> = var
+}
+
+%foo = func():u32 {
+  $B2: {
+    %3:ptr<workgroup, atomic<u32>, read_write> = access %v, 1u
+    %p:ptr<workgroup, atomic<u32>, read_write> = let %3
+    %5:u32 = atomicLoad %p
+    ret %5
+  }
+}
+)";
+
+    Run(PackedVec3);
+
+    EXPECT_EQ(expect, str());
+}
+
 }  // namespace
 }  // namespace tint::msl::writer::raise
diff --git a/test/tint/bug/tint/366314931.wgsl b/test/tint/bug/tint/366314931.wgsl
new file mode 100644
index 0000000..ca9864d
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl
@@ -0,0 +1,15 @@
+struct S {
+  v : vec3u,
+  u : atomic<u32>,
+}
+
+var<workgroup> wgvar: S;
+
+@group(0) @binding(0)
+var<storage, read_write> output: S;
+
+@compute @workgroup_size(1,1,1)
+fn main() {
+  let x = atomicLoad(&wgvar.u);
+  atomicStore(&output.u, x);
+}
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/366314931.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..2936f68
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.dxc.hlsl
@@ -0,0 +1,41 @@
+struct S {
+  uint3 v;
+  uint u;
+};
+
+groupshared S wgvar;
+
+void tint_zero_workgroup_memory(uint local_idx) {
+  if ((local_idx < 1u)) {
+    wgvar.v = (0u).xxx;
+    uint atomic_result = 0u;
+    InterlockedExchange(wgvar.u, 0u, atomic_result);
+  }
+  GroupMemoryBarrierWithGroupSync();
+}
+
+RWByteAddressBuffer output : register(u0);
+
+struct tint_symbol_1 {
+  uint local_invocation_index : SV_GroupIndex;
+};
+
+void outputatomicStore(uint offset, uint value) {
+  uint ignored;
+  output.InterlockedExchange(offset, value, ignored);
+}
+
+
+void main_inner(uint local_invocation_index) {
+  tint_zero_workgroup_memory(local_invocation_index);
+  uint atomic_result_1 = 0u;
+  InterlockedOr(wgvar.u, 0, atomic_result_1);
+  uint x = atomic_result_1;
+  outputatomicStore(12u, x);
+}
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+  main_inner(tint_symbol.local_invocation_index);
+  return;
+}
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/366314931.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..2936f68
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.fxc.hlsl
@@ -0,0 +1,41 @@
+struct S {
+  uint3 v;
+  uint u;
+};
+
+groupshared S wgvar;
+
+void tint_zero_workgroup_memory(uint local_idx) {
+  if ((local_idx < 1u)) {
+    wgvar.v = (0u).xxx;
+    uint atomic_result = 0u;
+    InterlockedExchange(wgvar.u, 0u, atomic_result);
+  }
+  GroupMemoryBarrierWithGroupSync();
+}
+
+RWByteAddressBuffer output : register(u0);
+
+struct tint_symbol_1 {
+  uint local_invocation_index : SV_GroupIndex;
+};
+
+void outputatomicStore(uint offset, uint value) {
+  uint ignored;
+  output.InterlockedExchange(offset, value, ignored);
+}
+
+
+void main_inner(uint local_invocation_index) {
+  tint_zero_workgroup_memory(local_invocation_index);
+  uint atomic_result_1 = 0u;
+  InterlockedOr(wgvar.u, 0, atomic_result_1);
+  uint x = atomic_result_1;
+  outputatomicStore(12u, x);
+}
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+  main_inner(tint_symbol.local_invocation_index);
+  return;
+}
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.glsl b/test/tint/bug/tint/366314931.wgsl.expected.glsl
new file mode 100644
index 0000000..99f652a
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.glsl
@@ -0,0 +1,31 @@
+#version 310 es
+
+struct S {
+  uvec3 v;
+  uint u;
+};
+
+shared S wgvar;
+void tint_zero_workgroup_memory(uint local_idx) {
+  if ((local_idx < 1u)) {
+    wgvar.v = uvec3(0u);
+    atomicExchange(wgvar.u, 0u);
+  }
+  barrier();
+}
+
+layout(binding = 0, std430) buffer tint_symbol_block_ssbo {
+  S inner;
+} tint_symbol;
+
+void tint_symbol_1(uint local_invocation_index) {
+  tint_zero_workgroup_memory(local_invocation_index);
+  uint x = atomicOr(wgvar.u, 0u);
+  atomicExchange(tint_symbol.inner.u, x);
+}
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void main() {
+  tint_symbol_1(gl_LocalInvocationIndex);
+  return;
+}
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.ir.dxc.hlsl b/test/tint/bug/tint/366314931.wgsl.expected.ir.dxc.hlsl
new file mode 100644
index 0000000..42c976c
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.ir.dxc.hlsl
@@ -0,0 +1,31 @@
+struct S {
+  uint3 v;
+  uint u;
+};
+
+struct main_inputs {
+  uint tint_local_index : SV_GroupIndex;
+};
+
+
+groupshared S wgvar;
+RWByteAddressBuffer output : register(u0);
+void main_inner(uint tint_local_index) {
+  if ((tint_local_index == 0u)) {
+    wgvar.v = (0u).xxx;
+    uint v_1 = 0u;
+    InterlockedExchange(wgvar.u, 0u, v_1);
+  }
+  GroupMemoryBarrierWithGroupSync();
+  uint v_2 = 0u;
+  InterlockedOr(wgvar.u, 0u, v_2);
+  uint x = v_2;
+  uint v_3 = 0u;
+  output.InterlockedExchange(uint(12u), x, v_3);
+}
+
+[numthreads(1, 1, 1)]
+void main(main_inputs inputs) {
+  main_inner(inputs.tint_local_index);
+}
+
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.ir.fxc.hlsl b/test/tint/bug/tint/366314931.wgsl.expected.ir.fxc.hlsl
new file mode 100644
index 0000000..42c976c
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.ir.fxc.hlsl
@@ -0,0 +1,31 @@
+struct S {
+  uint3 v;
+  uint u;
+};
+
+struct main_inputs {
+  uint tint_local_index : SV_GroupIndex;
+};
+
+
+groupshared S wgvar;
+RWByteAddressBuffer output : register(u0);
+void main_inner(uint tint_local_index) {
+  if ((tint_local_index == 0u)) {
+    wgvar.v = (0u).xxx;
+    uint v_1 = 0u;
+    InterlockedExchange(wgvar.u, 0u, v_1);
+  }
+  GroupMemoryBarrierWithGroupSync();
+  uint v_2 = 0u;
+  InterlockedOr(wgvar.u, 0u, v_2);
+  uint x = v_2;
+  uint v_3 = 0u;
+  output.InterlockedExchange(uint(12u), x, v_3);
+}
+
+[numthreads(1, 1, 1)]
+void main(main_inputs inputs) {
+  main_inner(inputs.tint_local_index);
+}
+
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.ir.glsl b/test/tint/bug/tint/366314931.wgsl.expected.ir.glsl
new file mode 100644
index 0000000..90455fa
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.ir.glsl
@@ -0,0 +1,11 @@
+SKIP: FAILED
+
+../../src/tint/lang/glsl/writer/printer/printer.cc:1394 internal compiler error: TINT_UNREACHABLE unhandled core builtin: atomicStore
+********************************************************************
+*  The tint shader compiler has encountered an unexpected error.   *
+*                                                                  *
+*  Please help us fix this issue by submitting a bug report at     *
+*  crbug.com/tint with the source program that triggered the bug.  *
+********************************************************************
+
+tint executable returned error: signal: trace/BPT trap
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.ir.msl b/test/tint/bug/tint/366314931.wgsl.expected.ir.msl
new file mode 100644
index 0000000..01df121
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.ir.msl
@@ -0,0 +1,31 @@
+#include <metal_stdlib>
+using namespace metal;
+
+struct S_packed_vec3 {
+  /* 0x0000 */ packed_uint3 v;
+  /* 0x000c */ atomic_uint u;
+};
+
+struct tint_module_vars_struct {
+  threadgroup S_packed_vec3* wgvar;
+  device S_packed_vec3* output;
+};
+
+struct tint_symbol_2 {
+  S_packed_vec3 tint_symbol_1;
+};
+
+void tint_symbol_inner(uint tint_local_index, tint_module_vars_struct tint_module_vars) {
+  if ((tint_local_index == 0u)) {
+    (*tint_module_vars.wgvar).v = packed_uint3(uint3(0u));
+    atomic_store_explicit((&(*tint_module_vars.wgvar).u), 0u, memory_order_relaxed);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  uint const x = atomic_load_explicit((&(*tint_module_vars.wgvar).u), memory_order_relaxed);
+  atomic_store_explicit((&(*tint_module_vars.output).u), x, memory_order_relaxed);
+}
+
+kernel void tint_symbol(uint tint_local_index [[thread_index_in_threadgroup]], threadgroup tint_symbol_2* v_1 [[threadgroup(0)]], device S_packed_vec3* output [[buffer(0)]]) {
+  tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.wgvar=(&(*v_1).tint_symbol_1), .output=output};
+  tint_symbol_inner(tint_local_index, tint_module_vars);
+}
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.msl b/test/tint/bug/tint/366314931.wgsl.expected.msl
new file mode 100644
index 0000000..2b6f1f3
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.msl
@@ -0,0 +1,33 @@
+#include <metal_stdlib>
+
+using namespace metal;
+struct S_tint_packed_vec3 {
+  /* 0x0000 */ packed_uint3 v;
+  /* 0x000c */ atomic_uint u;
+};
+
+struct S {
+  uint3 v;
+  atomic_uint u;
+};
+
+void tint_zero_workgroup_memory(uint local_idx, threadgroup S* const tint_symbol_1) {
+  if ((local_idx < 1u)) {
+    (*(tint_symbol_1)).v = uint3(0u);
+    atomic_store_explicit(&((*(tint_symbol_1)).u), 0u, memory_order_relaxed);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+}
+
+void tint_symbol_inner(uint local_invocation_index, threadgroup S* const tint_symbol_2, device S_tint_packed_vec3* const tint_symbol_3) {
+  tint_zero_workgroup_memory(local_invocation_index, tint_symbol_2);
+  uint const x = atomic_load_explicit(&((*(tint_symbol_2)).u), memory_order_relaxed);
+  atomic_store_explicit(&((*(tint_symbol_3)).u), x, memory_order_relaxed);
+}
+
+kernel void tint_symbol(device S_tint_packed_vec3* tint_symbol_5 [[buffer(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
+  threadgroup S tint_symbol_4;
+  tint_symbol_inner(local_invocation_index, &(tint_symbol_4), tint_symbol_5);
+  return;
+}
+
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.spvasm b/test/tint/bug/tint/366314931.wgsl.expected.spvasm
new file mode 100644
index 0000000..3000342
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.spvasm
@@ -0,0 +1,76 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 1
+; Bound: 41
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %main_local_invocation_index_Input
+               OpExecutionMode %main LocalSize 1 1 1
+               OpMemberName %S 0 "v"
+               OpMemberName %S 1 "u"
+               OpName %S "S"
+               OpName %wgvar "wgvar"
+               OpMemberName %tint_symbol_1 0 "tint_symbol"
+               OpName %tint_symbol_1 "tint_symbol_1"
+               OpName %main_local_invocation_index_Input "main_local_invocation_index_Input"
+               OpName %main_inner "main_inner"
+               OpName %tint_local_index "tint_local_index"
+               OpName %x "x"
+               OpName %main "main"
+               OpMemberDecorate %S 0 Offset 0
+               OpMemberDecorate %S 1 Offset 12
+               OpMemberDecorate %tint_symbol_1 0 Offset 0
+               OpDecorate %tint_symbol_1 Block
+               OpDecorate %6 DescriptorSet 0
+               OpDecorate %6 Binding 0
+               OpDecorate %6 Coherent
+               OpDecorate %main_local_invocation_index_Input BuiltIn LocalInvocationIndex
+       %uint = OpTypeInt 32 0
+     %v3uint = OpTypeVector %uint 3
+          %S = OpTypeStruct %v3uint %uint
+%_ptr_Workgroup_S = OpTypePointer Workgroup %S
+      %wgvar = OpVariable %_ptr_Workgroup_S Workgroup
+%tint_symbol_1 = OpTypeStruct %S
+%_ptr_StorageBuffer_tint_symbol_1 = OpTypePointer StorageBuffer %tint_symbol_1
+          %6 = OpVariable %_ptr_StorageBuffer_tint_symbol_1 StorageBuffer
+%_ptr_Input_uint = OpTypePointer Input %uint
+%main_local_invocation_index_Input = OpVariable %_ptr_Input_uint Input
+       %void = OpTypeVoid
+         %14 = OpTypeFunction %void %uint
+     %uint_0 = OpConstant %uint 0
+       %bool = OpTypeBool
+%_ptr_Workgroup_v3uint = OpTypePointer Workgroup %v3uint
+         %23 = OpConstantNull %v3uint
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+     %uint_1 = OpConstant %uint 1
+     %uint_2 = OpConstant %uint 2
+   %uint_264 = OpConstant %uint 264
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+         %37 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %14
+%tint_local_index = OpFunctionParameter %uint
+         %15 = OpLabel
+         %16 = OpIEqual %bool %tint_local_index %uint_0
+               OpSelectionMerge %19 None
+               OpBranchConditional %16 %20 %19
+         %20 = OpLabel
+         %21 = OpAccessChain %_ptr_Workgroup_v3uint %wgvar %uint_0
+               OpStore %21 %23 None
+         %24 = OpAccessChain %_ptr_Workgroup_uint %wgvar %uint_1
+               OpAtomicStore %24 %uint_2 %uint_0 %uint_0
+               OpBranch %19
+         %19 = OpLabel
+               OpControlBarrier %uint_2 %uint_2 %uint_264
+         %31 = OpAccessChain %_ptr_Workgroup_uint %wgvar %uint_1
+          %x = OpAtomicLoad %uint %31 %uint_2 %uint_0
+         %33 = OpAccessChain %_ptr_StorageBuffer_uint %6 %uint_0 %uint_1
+               OpAtomicStore %33 %uint_1 %uint_0 %x
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %37
+         %38 = OpLabel
+         %39 = OpLoad %uint %main_local_invocation_index_Input None
+         %40 = OpFunctionCall %void %main_inner %39
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/bug/tint/366314931.wgsl.expected.wgsl b/test/tint/bug/tint/366314931.wgsl.expected.wgsl
new file mode 100644
index 0000000..971caad
--- /dev/null
+++ b/test/tint/bug/tint/366314931.wgsl.expected.wgsl
@@ -0,0 +1,14 @@
+struct S {
+  v : vec3u,
+  u : atomic<u32>,
+}
+
+var<workgroup> wgvar : S;
+
+@group(0) @binding(0) var<storage, read_write> output : S;
+
+@compute @workgroup_size(1, 1, 1)
+fn main() {
+  let x = atomicLoad(&(wgvar.u));
+  atomicStore(&(output.u), x);
+}