[spirv-reader][ir] Handle UserCall atomic conversions.

When processing `UserCall` instructions for atomic conversions we may
need to fork the called functions. In order to fork, we have to collect
up all of the parameters which become atomics in that specific user call
and make a custom forked function. We store this combination of
(function, parameters converted) so we only fork the function once for
each combination of arguments.

Bug: 404501988
Change-Id: I6bd318151d9e611ffae038b14c4598606f48a3ef
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/238594
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/spirv/reader/lower/atomics.cc b/src/tint/lang/spirv/reader/lower/atomics.cc
index e550296..0a52bd6 100644
--- a/src/tint/lang/spirv/reader/lower/atomics.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics.cc
@@ -30,13 +30,13 @@
 #include <utility>
 
 #include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/clone_context.h"
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/validator.h"
 #include "src/tint/lang/core/type/builtin_structs.h"
 #include "src/tint/lang/spirv/ir/builtin_call.h"
 #include "src/tint/utils/containers/hashmap.h"
 #include "src/tint/utils/containers/hashset.h"
-#include "src/tint/utils/containers/vector.h"
 
 namespace tint::spirv::reader::lower {
 namespace {
@@ -62,9 +62,18 @@
     /// loads/stores updated to match. This maps to the root FunctionParam or Var for each atomic.
     Vector<core::ir::Value*, 8> values_to_fix_usages_{};
 
+    /// Any `ir::UserCall` instructions which have atomic params that need to
+    /// be updated.
+    Hashset<core::ir::UserCall*, 2> user_calls_to_convert_{};
+
     /// The `ir::Value`s which have been converted
     Hashset<core::ir::Value*, 8> converted_{};
 
+    /// Map of forked-function replacements. Keyed by hash code, since the
+    /// key combines the function pointer with the specific parameters which
+    /// are converted to atomics.
+    Hashmap<size_t, core::ir::Function*, 4> func_hash_to_func_{};
+
     struct ForkedStruct {
         const core::type::Struct* src_struct = nullptr;
         const core::type::Struct* dst_struct = nullptr;
@@ -151,9 +160,22 @@
         ProcessForkedStructs();
         ReplaceStructTypes();
 
+        // The double loop happens because converting user calls can add more
+        // values to convert, and those values can, in turn, find more user
+        // calls to convert, so we iterate until we stabilize.
         while (!values_to_fix_usages_.IsEmpty()) {
-            auto* val = values_to_fix_usages_.Pop();
-            ConvertUsagesToAtomic(val);
+            while (!values_to_fix_usages_.IsEmpty()) {
+                auto* val = values_to_fix_usages_.Pop();
+                ConvertUsagesToAtomic(val);
+            }
+
+            auto user_calls = user_calls_to_convert_.Vector();
+            // Sort for deterministic output
+            user_calls.Sort();
+            for (auto& call : user_calls) {
+                ConvertUserCall(call);
+            }
+            user_calls_to_convert_.Clear();
         }
     }
 
@@ -353,9 +375,7 @@
                         values_to_fix_usages_.Push(res);
                     }
                 },
-                [&](core::ir::UserCall*) {
-                    // This should have been handled as a function parameter above.
-                },
+                [&](core::ir::UserCall* uc) { user_calls_to_convert_.Add(uc); },
                 [&](core::ir::CoreBuiltinCall* bc) {
                     // This was converted when we switched from a SPIR-V intrinsic to core
                     TINT_ASSERT(core::IsAtomic(bc->Func()));
@@ -365,6 +385,46 @@
         });
     }
 
+    // A user call needs to check which of its arguments were converted to
+    // atomics and create (or reuse) a forked function for that specific
+    // combination of parameters.
+    void ConvertUserCall(core::ir::UserCall* uc) {
+        auto* target = uc->Target();
+        auto& params = target->Params();
+        const auto& args = uc->Args();
+
+        Vector<size_t, 2> to_convert;
+        for (size_t i = 0; i < args.Length(); ++i) {
+            if (params[i]->Type() != args[i]->Type()) {
+                to_convert.Push(i);
+            }
+        }
+        // Everything is already converted, so we're done.
+        if (to_convert.IsEmpty()) {
+            return;
+        }
+
+        // Hash based on the original function pointer and the specific
+        // parameters we're converting.
+        auto hash = Hash(target);
+        hash = HashCombine(hash, to_convert);
+
+        auto* new_fn = func_hash_to_func_.GetOrAdd(hash, [&] {
+            core::ir::CloneContext ctx{ir};
+            auto* fn = uc->Target()->Clone(ctx);
+            ir.functions.Push(fn);
+
+            for (auto idx : to_convert) {
+                auto* p = fn->Params()[idx];
+                p->SetType(args[idx]->Type());
+
+                values_to_fix_usages_.Push(p);
+            }
+            return fn;
+        });
+        uc->SetTarget(new_fn);
+    }
+
     const core::type::Type* TypeForAccess(core::ir::Access* access) {
         auto* ptr = access->Object()->Type()->As<core::type::Pointer>();
         TINT_ASSERT(ptr);
diff --git a/src/tint/lang/spirv/reader/lower/atomics_test.cc b/src/tint/lang/spirv/reader/lower/atomics_test.cc
index 432b875..a5ef867 100644
--- a/src/tint/lang/spirv/reader/lower/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics_test.cc
@@ -2952,5 +2952,256 @@
     ASSERT_EQ(expect, str());
 }
 
+TEST_F(SpirvReader_AtomicsTest, FunctionParam_AnotherCallWithNonAtomicUse) {
+    core::ir::Var* wg_atomic = nullptr;
+    core::ir::Var* wg_nonatomic = nullptr;
+    b.Append(mod.root_block, [&] {
+        wg_atomic = b.Var("wg_atomic", ty.ptr<workgroup, u32>());
+        wg_nonatomic = b.Var("wg_nonatomic", ty.ptr<workgroup, u32>());
+    });
+
+    auto* f_atomic = b.Function("f_atomic", ty.u32());
+    b.Append(f_atomic->Block(), [&] {
+        auto* p = b.FunctionParam("param", ty.ptr<workgroup, u32>());
+        f_atomic->SetParams({p});
+
+        auto* ret =
+            b.Call<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kAtomicLoad, p, 1_u, 0_u);
+        b.Return(f_atomic, ret);
+    });
+
+    auto* f_nonatomic = b.Function("f_nonatomic", ty.u32());
+    b.Append(f_nonatomic->Block(), [&] {
+        auto* p = b.FunctionParam("param", ty.ptr<workgroup, u32>());
+        f_nonatomic->SetParams({p});
+
+        auto* ret = b.Load(p);
+        b.Return(f_nonatomic, ret);
+    });
+
+    auto* main = b.ComputeFunction("main");
+    b.Append(main->Block(), [&] {  //
+        b.Call(ty.u32(), f_atomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_nonatomic);
+        b.Return(main);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %wg_atomic:ptr<workgroup, u32, read_write> = var undef
+  %wg_nonatomic:ptr<workgroup, u32, read_write> = var undef
+}
+
+%f_atomic = func(%param:ptr<workgroup, u32, read_write>):u32 {
+  $B2: {
+    %5:u32 = spirv.atomic_load %param, 1u, 0u
+    ret %5
+  }
+}
+%f_nonatomic = func(%param_1:ptr<workgroup, u32, read_write>):u32 {  # %param_1: 'param'
+  $B3: {
+    %8:u32 = load %param_1
+    ret %8
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B4: {
+    %10:u32 = call %f_atomic, %wg_atomic
+    %11:u32 = call %f_nonatomic, %wg_atomic
+    %12:u32 = call %f_nonatomic, %wg_atomic
+    %13:u32 = call %f_nonatomic, %wg_nonatomic
+    ret
+  }
+}
+)";
+
+    ASSERT_EQ(src, str());
+    Run(Atomics);
+
+    auto* expect = R"(
+$B1: {  # root
+  %wg_atomic:ptr<workgroup, atomic<u32>, read_write> = var undef
+  %wg_nonatomic:ptr<workgroup, u32, read_write> = var undef
+}
+
+%f_atomic = func(%param:ptr<workgroup, atomic<u32>, read_write>):u32 {
+  $B2: {
+    %5:u32 = atomicLoad %param
+    ret %5
+  }
+}
+%f_nonatomic = func(%param_1:ptr<workgroup, u32, read_write>):u32 {  # %param_1: 'param'
+  $B3: {
+    %8:u32 = load %param_1
+    ret %8
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B4: {
+    %10:u32 = call %f_atomic, %wg_atomic
+    %11:u32 = call %f_nonatomic_1, %wg_atomic
+    %13:u32 = call %f_nonatomic_1, %wg_atomic
+    %14:u32 = call %f_nonatomic, %wg_nonatomic
+    ret
+  }
+}
+%f_nonatomic_1 = func(%param_2:ptr<workgroup, atomic<u32>, read_write>):u32 {  # %f_nonatomic_1: 'f_nonatomic', %param_2: 'param'
+  $B5: {
+    %16:u32 = atomicLoad %param_2
+    ret %16
+  }
+}
+)";
+    ASSERT_EQ(expect, str());
+}
+
+TEST_F(SpirvReader_AtomicsTest, FunctionParam_MixedCalls) {
+    core::ir::Var* wg_atomic = nullptr;
+    core::ir::Var* wg_nonatomic = nullptr;
+    b.Append(mod.root_block, [&] {
+        wg_atomic = b.Var("wg_atomic", ty.ptr<workgroup, u32>());
+        wg_nonatomic = b.Var("wg_nonatomic", ty.ptr<workgroup, u32>());
+    });
+
+    auto* f_atomic = b.Function("f_atomic", ty.u32());
+    b.Append(f_atomic->Block(), [&] {
+        auto* p = b.FunctionParam("param", ty.ptr<workgroup, u32>());
+        f_atomic->SetParams({p});
+
+        auto* ret =
+            b.Call<spirv::ir::BuiltinCall>(ty.u32(), spirv::BuiltinFn::kAtomicLoad, p, 1_u, 0_u);
+        b.Return(f_atomic, ret);
+    });
+
+    auto* f_nonatomic = b.Function("f_nonatomic", ty.u32());
+    b.Append(f_nonatomic->Block(), [&] {
+        auto* p1 = b.FunctionParam("param1", ty.ptr<workgroup, u32>());
+        auto* p2 = b.FunctionParam("param2", ty.ptr<workgroup, u32>());
+        f_nonatomic->SetParams({p1, p2});
+
+        auto* one = b.Load(p1);
+        auto* two = b.Load(p2);
+        b.Return(f_nonatomic, b.Add(ty.u32(), one, two));
+    });
+
+    auto* main = b.ComputeFunction("main");
+    b.Append(main->Block(), [&] {  //
+        b.Call(ty.u32(), f_atomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_atomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_atomic, wg_nonatomic);
+        b.Call(ty.u32(), f_nonatomic, wg_nonatomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_nonatomic, wg_nonatomic);
+
+        // Duplicate the calls to make sure the functions don't duplicate
+        b.Call(ty.u32(), f_nonatomic, wg_atomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_atomic, wg_nonatomic);
+        b.Call(ty.u32(), f_nonatomic, wg_nonatomic, wg_atomic);
+        b.Call(ty.u32(), f_nonatomic, wg_nonatomic, wg_nonatomic);
+        b.Return(main);
+    });
+
+    auto* src = R"(
+$B1: {  # root
+  %wg_atomic:ptr<workgroup, u32, read_write> = var undef
+  %wg_nonatomic:ptr<workgroup, u32, read_write> = var undef
+}
+
+%f_atomic = func(%param:ptr<workgroup, u32, read_write>):u32 {
+  $B2: {
+    %5:u32 = spirv.atomic_load %param, 1u, 0u
+    ret %5
+  }
+}
+%f_nonatomic = func(%param1:ptr<workgroup, u32, read_write>, %param2:ptr<workgroup, u32, read_write>):u32 {
+  $B3: {
+    %9:u32 = load %param1
+    %10:u32 = load %param2
+    %11:u32 = add %9, %10
+    ret %11
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B4: {
+    %13:u32 = call %f_atomic, %wg_atomic
+    %14:u32 = call %f_nonatomic, %wg_atomic, %wg_atomic
+    %15:u32 = call %f_nonatomic, %wg_atomic, %wg_nonatomic
+    %16:u32 = call %f_nonatomic, %wg_nonatomic, %wg_atomic
+    %17:u32 = call %f_nonatomic, %wg_nonatomic, %wg_nonatomic
+    %18:u32 = call %f_nonatomic, %wg_atomic, %wg_atomic
+    %19:u32 = call %f_nonatomic, %wg_atomic, %wg_nonatomic
+    %20:u32 = call %f_nonatomic, %wg_nonatomic, %wg_atomic
+    %21:u32 = call %f_nonatomic, %wg_nonatomic, %wg_nonatomic
+    ret
+  }
+}
+)";
+
+    ASSERT_EQ(src, str());
+    Run(Atomics);
+
+    auto* expect = R"(
+$B1: {  # root
+  %wg_atomic:ptr<workgroup, atomic<u32>, read_write> = var undef
+  %wg_nonatomic:ptr<workgroup, u32, read_write> = var undef
+}
+
+%f_atomic = func(%param:ptr<workgroup, atomic<u32>, read_write>):u32 {
+  $B2: {
+    %5:u32 = atomicLoad %param
+    ret %5
+  }
+}
+%f_nonatomic = func(%param1:ptr<workgroup, u32, read_write>, %param2:ptr<workgroup, u32, read_write>):u32 {
+  $B3: {
+    %9:u32 = load %param1
+    %10:u32 = load %param2
+    %11:u32 = add %9, %10
+    ret %11
+  }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B4: {
+    %13:u32 = call %f_atomic, %wg_atomic
+    %14:u32 = call %f_nonatomic_1, %wg_atomic, %wg_atomic
+    %16:u32 = call %f_nonatomic_2, %wg_atomic, %wg_nonatomic
+    %18:u32 = call %f_nonatomic_3, %wg_nonatomic, %wg_atomic
+    %20:u32 = call %f_nonatomic, %wg_nonatomic, %wg_nonatomic
+    %21:u32 = call %f_nonatomic_1, %wg_atomic, %wg_atomic
+    %22:u32 = call %f_nonatomic_2, %wg_atomic, %wg_nonatomic
+    %23:u32 = call %f_nonatomic_3, %wg_nonatomic, %wg_atomic
+    %24:u32 = call %f_nonatomic, %wg_nonatomic, %wg_nonatomic
+    ret
+  }
+}
+%f_nonatomic_1 = func(%param1_1:ptr<workgroup, atomic<u32>, read_write>, %param2_1:ptr<workgroup, atomic<u32>, read_write>):u32 {  # %f_nonatomic_1: 'f_nonatomic', %param1_1: 'param1', %param2_1: 'param2'
+  $B5: {
+    %27:u32 = atomicLoad %param1_1
+    %28:u32 = atomicLoad %param2_1
+    %29:u32 = add %27, %28
+    ret %29
+  }
+}
+%f_nonatomic_2 = func(%param1_2:ptr<workgroup, atomic<u32>, read_write>, %param2_2:ptr<workgroup, u32, read_write>):u32 {  # %f_nonatomic_2: 'f_nonatomic', %param1_2: 'param1', %param2_2: 'param2'
+  $B6: {
+    %32:u32 = atomicLoad %param1_2
+    %33:u32 = load %param2_2
+    %34:u32 = add %32, %33
+    ret %34
+  }
+}
+%f_nonatomic_3 = func(%param1_3:ptr<workgroup, u32, read_write>, %param2_3:ptr<workgroup, atomic<u32>, read_write>):u32 {  # %f_nonatomic_3: 'f_nonatomic', %param1_3: 'param1', %param2_3: 'param2'
+  $B7: {
+    %37:u32 = load %param1_3
+    %38:u32 = atomicLoad %param2_3
+    %39:u32 = add %37, %38
+    ret %39
+  }
+}
+)";
+    ASSERT_EQ(expect, str());
+}
+
 }  // namespace
 }  // namespace tint::spirv::reader::lower