[spirv-reader][ir] Support propagating atomics through function params.

Add support for atomics passed as functions parameters to the spir-v IR
reader.

Bug: 406616153
Change-Id: I5170875b3e8e34aa4aa793fcc5ba2a57586b260f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/235615
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/lower/atomics.cc b/src/tint/lang/spirv/reader/lower/atomics.cc
index 244d8ea..7c422ea 100644
--- a/src/tint/lang/spirv/reader/lower/atomics.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics.cc
@@ -198,21 +198,38 @@
     }
 
     void ConvertAtomicValue(core::ir::Value* val) {
-        auto* res = val->As<core::ir::InstructionResult>();
-        TINT_ASSERT(res);
+        tint::Switch(  //
+            val,       //
+            [&](core::ir::InstructionResult* res) {
+                auto* orig_ty = res->Type();
+                auto* atomic_ty = AtomicTypeFor(val, orig_ty);
+                res->SetType(atomic_ty);
 
-        auto* orig_ty = res->Type();
-        auto* atomic_ty = AtomicTypeFor(val, orig_ty);
-        res->SetType(atomic_ty);
+                tint::Switch(            //
+                    res->Instruction(),  //
+                    [&](core::ir::Access* a) {
+                        CheckForStructForking(a);
+                        values_to_convert_.Push(a->Object());
+                    },                                                               //
+                    [&](core::ir::Let* l) { values_to_convert_.Push(l->Value()); },  //
+                    [&](core::ir::Var*) {},                                          //
+                    TINT_ICE_ON_NO_MATCH);
+            },
+            [&](core::ir::FunctionParam* param) {
+                auto* orig_ty = param->Type();
+                auto* atomic_ty = AtomicTypeFor(val, orig_ty);
+                param->SetType(atomic_ty);
 
-        tint::Switch(            //
-            res->Instruction(),  //
-            [&](core::ir::Access* a) {
-                CheckForStructForking(a);
-                values_to_convert_.Push(a->Object());
-            },                                                               //
-            [&](core::ir::Let* l) { values_to_convert_.Push(l->Value()); },  //
-            [&](core::ir::Var*) {},                                          //
+                for (auto& usage : param->Function()->UsagesUnsorted()) {
+                    if (usage->instruction->Is<core::ir::Return>()) {
+                        continue;
+                    }
+
+                    auto* call = usage->instruction->As<core::ir::Call>();
+                    TINT_ASSERT(call);
+                    values_to_convert_.Push(call->Args()[param->Index()]);
+                }
+            },
             TINT_ICE_ON_NO_MATCH);
     }
 
diff --git a/src/tint/lang/spirv/reader/lower/atomics_test.cc b/src/tint/lang/spirv/reader/lower/atomics_test.cc
index 95a82a5..0bf3fa7 100644
--- a/src/tint/lang/spirv/reader/lower/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics_test.cc
@@ -691,6 +691,75 @@
     ASSERT_EQ(expect, str());
 }
 
+TEST_F(SpirvReader_AtomicsTest, FunctionParam) {
+    auto* c = b.Function("c", ty.void_());
+    auto* p = b.FunctionParam("param", ty.ptr(workgroup, ty.array<u32, 4>(), read_write));
+    c->SetParams({p});
+
+    b.Append(c->Block(), [&] {
+        auto* a = b.Access(ty.ptr<workgroup, u32, read_write>(), p, 1_i);
+        b.Call<spirv::ir::BuiltinCall>(ty.void_(), spirv::BuiltinFn::kAtomicStore, a, 2_u, 0_u,
+                                       1_u);
+
+        b.Return(c);
+    });
+
+    auto* f = b.ComputeFunction("main");
+
+    core::ir::Var* wg = nullptr;
+    b.Append(mod.root_block,
+             [&] { wg = b.Var("wg", ty.ptr(workgroup, ty.array<u32, 4>(), read_write)); });
+
+    b.Append(f->Block(), [&] {  //
+        b.Call(ty.void_(), c, wg);
+        b.Return(f);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %wg:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%c = func(%param:ptr<workgroup, array<u32, 4>, read_write>):void {
+  $B2: {
+    %4:ptr<workgroup, u32, read_write> = access %param, 1i
+    %5:void = spirv.atomic_store %4, 2u, 0u, 1u
+    ret
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B3: {
+    %7:void = call %c, %wg
+    ret
+  }
+}
+)";
+
+    ASSERT_EQ(src, str());
+    Run(Atomics);
+
+    auto* expect = R"(
+$B1: {  # root
+  %wg:ptr<workgroup, array<atomic<u32>, 4>, read_write> = var undef
+}
+
+%c = func(%param:ptr<workgroup, array<atomic<u32>, 4>, read_write>):void {
+  $B2: {
+    %4:ptr<workgroup, atomic<u32>, read_write> = access %param, 1i
+    %5:void = atomicStore %4, 1u
+    ret
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B3: {
+    %7:void = call %c, %wg
+    ret
+  }
+}
+)";
+    ASSERT_EQ(expect, str());
+}
+
 TEST_F(SpirvReader_AtomicsTest, AtomicAdd) {
     auto* f = b.ComputeFunction("main");
 
diff --git a/src/tint/lang/spirv/reader/parser/atomics_test.cc b/src/tint/lang/spirv/reader/parser/atomics_test.cc
index 5e71efc..70344e3 100644
--- a/src/tint/lang/spirv/reader/parser/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/parser/atomics_test.cc
@@ -577,6 +577,125 @@
 )");
 }
 
+TEST_F(SpirvParser_AtomicsTest, FunctionParam) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpName %wg "wg"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+     %uint_2 = OpConstant %uint 2
+     %uint_4 = OpConstant %uint 4
+      %int_1 = OpConstant %int 1
+        %arr = OpTypeArray %uint %uint_4
+    %ptr_arr = OpTypePointer Workgroup %arr
+   %ptr_uint = OpTypePointer Workgroup %uint
+       %void = OpTypeVoid
+         %10 = OpTypeFunction %void
+         %11 = OpTypeFunction %void %ptr_arr
+         %wg = OpVariable %ptr_arr Workgroup
+
+        %foo = OpFunction %void None %11
+      %param = OpFunctionParameter %ptr_arr
+  %foo_start = OpLabel
+         %42 = OpAccessChain %ptr_uint %param %int_1
+               OpAtomicStore %42 %uint_2 %uint_0 %uint_1
+               OpReturn
+               OpFunctionEnd
+
+       %main = OpFunction %void None %10
+         %45 = OpLabel
+         %44 = OpFunctionCall %void %foo %wg
+               OpReturn
+               OpFunctionEnd
+)",
+              R"(
+$B1: {  # root
+  %wg:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%2 = func(%3:ptr<workgroup, array<u32, 4>, read_write>):void {
+  $B2: {
+    %4:ptr<workgroup, u32, read_write> = access %3, 1i
+    %5:void = spirv.atomic_store %4, 2u, 0u, 1u
+    ret
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B3: {
+    %7:void = call %2, %wg
+    ret
+  }
+}
+)");
+}
+
+// TODO(dsinclair): Requires support for variable pointers
+TEST_F(SpirvParser_AtomicsTest, DISABLED_FunctionParam_subpointer) {
+    EXPECT_IR(R"(
+               OpCapability Shader
+               OpCapability VariablePointers
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpName %wg "wg"
+               OpName %main "main"
+        %int = OpTypeInt 32 1
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+     %uint_2 = OpConstant %uint 2
+     %uint_4 = OpConstant %uint 4
+      %int_1 = OpConstant %int 1
+        %arr = OpTypeArray %uint %uint_4
+    %ptr_arr = OpTypePointer Workgroup %arr
+   %ptr_uint = OpTypePointer Workgroup %uint
+       %void = OpTypeVoid
+         %10 = OpTypeFunction %void
+         %11 = OpTypeFunction %void %ptr_uint
+         %wg = OpVariable %ptr_arr Workgroup
+
+        %foo = OpFunction %void None %11
+      %param = OpFunctionParameter %ptr_uint
+  %foo_start = OpLabel
+               OpAtomicStore %param %uint_2 %uint_0 %uint_1
+               OpReturn
+               OpFunctionEnd
+
+       %main = OpFunction %void None %10
+         %45 = OpLabel
+         %42 = OpAccessChain %ptr_uint %wg %int_1
+         %44 = OpFunctionCall %void %foo %42
+               OpReturn
+               OpFunctionEnd
+)",
+              R"(
+$B1: {  # root
+  %wg:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%2 = func(%3:ptr<workgroup, array<u32, 4>, read_write>):void {
+  $B2: {
+    %4:ptr<workgroup, u32, read_write> = access %3, 1i
+    %5:void = spirv.atomic_store %4, 2u, 0u, 1u
+    ret
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B3: {
+    %7:void = call %2, %wg
+    ret
+  }
+}
+)");
+}
+
 TEST_F(SpirvParser_AtomicsTest, AtomicAdd) {
     EXPECT_IR(R"(
                OpCapability Shader