[spirv-reader][ir] Convert spirv atomic_store builtin.

Add initial atomic transform support and convert the
`spirv.atomic_store` builtin to a core `atomicStore` WGSL instruction.

Bug: 391487430
Change-Id: I03658d0d3f2e82e8717d9a18e4e85805da68ebeb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/233054
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/lower/atomics.cc b/src/tint/lang/spirv/reader/lower/atomics.cc
index c8dbe9c..ac91882 100644
--- a/src/tint/lang/spirv/reader/lower/atomics.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics.cc
@@ -30,6 +30,9 @@
 #include "src/tint/lang/core/ir/builder.h"
 #include "src/tint/lang/core/ir/module.h"
 #include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/spirv/ir/builtin_call.h"
+#include "src/tint/utils/containers/hashset.h"
+#include "src/tint/utils/containers/vector.h"
 
 namespace tint::spirv::reader::lower {
 namespace {
@@ -48,8 +51,103 @@
     /// The type manager.
     core::type::Manager& ty{ir.Types()};
 
+    /// The `ir::Value`s to be converted to atomics
+    Vector<core::ir::Value*, 8> values_to_convert_{};
+
+    /// The `ir::Value`s which have been converted
+    Hashset<core::ir::Value*, 8> converted_{};
+
     /// Process the module.
-    void Process() {}
+    void Process() {
+        Vector<spirv::ir::BuiltinCall*, 4> builtin_worklist;
+        for (auto* inst : ir.Instructions()) {
+            if (auto* builtin = inst->As<spirv::ir::BuiltinCall>()) {
+                builtin_worklist.Push(builtin);
+            }
+        }
+
+        for (auto* builtin : builtin_worklist) {
+            switch (builtin->Func()) {
+                case spirv::BuiltinFn::kAtomicLoad:
+                    break;
+                case spirv::BuiltinFn::kAtomicStore:
+                    AtomicStore(builtin);
+                    break;
+                case spirv::BuiltinFn::kAtomicExchange:
+                case spirv::BuiltinFn::kAtomicCompareExchange:
+                case spirv::BuiltinFn::kAtomicIAdd:
+                case spirv::BuiltinFn::kAtomicISub:
+                case spirv::BuiltinFn::kAtomicSMax:
+                case spirv::BuiltinFn::kAtomicSMin:
+                case spirv::BuiltinFn::kAtomicUMax:
+                case spirv::BuiltinFn::kAtomicUMin:
+                case spirv::BuiltinFn::kAtomicAnd:
+                case spirv::BuiltinFn::kAtomicOr:
+                case spirv::BuiltinFn::kAtomicXor:
+                case spirv::BuiltinFn::kAtomicIIncrement:
+                case spirv::BuiltinFn::kAtomicIDecrement:
+                    break;
+                default:
+                    TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
+            }
+        }
+
+        while (!values_to_convert_.IsEmpty()) {
+            auto* val = values_to_convert_.Pop();
+
+            if (converted_.Add(val)) {
+                ConvertAtomicValue(val);
+            }
+        }
+    }
+
+    void AtomicStore(spirv::ir::BuiltinCall* call) {
+        auto args = call->Args();
+
+        b.InsertBefore(call, [&] {
+            auto* var = args[0];
+            values_to_convert_.Push(var);
+
+            auto* val = args[3];
+            b.CallWithResult(call->DetachResult(), core::BuiltinFn::kAtomicStore, var, val);
+        });
+        call->Destroy();
+    }
+
+    void ConvertAtomicValue(core::ir::Value* val) {
+        auto* res = val->As<core::ir::InstructionResult>();
+        TINT_ASSERT(res);
+
+        auto* atomic_ty = AtomicTypeFor(res->Type());
+        res->SetType(atomic_ty);
+
+        tint::Switch(                                                            //
+            res->Instruction(),                                                  //
+            [&](core::ir::Access* a) { values_to_convert_.Push(a->Object()); },  //
+            [&](core::ir::Var*) {},                                              //
+            TINT_ICE_ON_NO_MATCH);
+    }
+
+    const core::type::Type* AtomicTypeFor(const core::type::Type* orig_ty) {
+        return tint::Switch(
+            orig_ty,  //
+            [&](const core::type::I32*) { return ty.atomic(orig_ty); },
+            [&](const core::type::U32*) { return ty.atomic(orig_ty); },
+            // [&](const core::type::Struct* str) { return ty(Fork(str).name); },
+            [&](const core::type::Array* arr) {
+                if (arr->Count()->Is<core::type::RuntimeArrayCount>()) {
+                    return ty.runtime_array(AtomicTypeFor(arr->ElemType()));
+                }
+                auto count = arr->ConstantCount();
+                TINT_ASSERT(count);
+
+                return ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
+            },
+            [&](const core::type::Pointer* ptr) {
+                return ty.ptr(ptr->AddressSpace(), AtomicTypeFor(ptr->StoreType()), ptr->Access());
+            },
+            TINT_ICE_ON_NO_MATCH);
+    }
 };
 
 }  // namespace
diff --git a/src/tint/lang/spirv/reader/lower/atomics_test.cc b/src/tint/lang/spirv/reader/lower/atomics_test.cc
index 5450a7e..d3bef8e 100644
--- a/src/tint/lang/spirv/reader/lower/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics_test.cc
@@ -38,7 +38,7 @@
 
 using SpirvReader_AtomicsTest = core::ir::transform::TransformTest;
 
-TEST_F(SpirvReader_AtomicsTest, DISABLED_ArrayStore) {
+TEST_F(SpirvReader_AtomicsTest, ArrayStore) {
     auto* f = b.ComputeFunction("main");
 
     core::ir::Var* wg = nullptr;
@@ -48,7 +48,7 @@
     b.Append(f->Block(), [&] {  //
         auto* a = b.Access(ty.ptr<workgroup, u32, read_write>(), wg, 1_i);
         b.Call<spirv::ir::BuiltinCall>(ty.void_(), spirv::BuiltinFn::kAtomicStore, a, 1_u, 0_u,
-                                       0_u);
+                                       1_u);
         b.Return(f);
     });
 
@@ -60,7 +60,7 @@
 %main = @compute @workgroup_size(1u, 1u, 1u) func():void {
   $B2: {
     %3:ptr<workgroup, u32, read_write> = access %wg, 1i
-    %4:void = spirv.atomic_store %3, 1u, 0u, 0u
+    %4:void = spirv.atomic_store %3, 1u, 0u, 1u
     ret
   }
 }
@@ -69,7 +69,17 @@
     Run(Atomics);
 
     auto* expect = R"(
-UNIMPLEMENTED
+$B1: {  # root
+  %wg:ptr<workgroup, array<atomic<u32>, 4>, read_write> = var undef
+}
+
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B2: {
+    %3:ptr<workgroup, atomic<u32>, read_write> = access %wg, 1i
+    %4:void = atomicStore %3, 1u
+    ret
+  }
+}
 )";
     ASSERT_EQ(expect, str());
 }
@@ -112,7 +122,7 @@
     ASSERT_EQ(expect, str());
 }
 
-TEST_F(SpirvReader_AtomicsTest, DISABLED_ArrayNested) {
+TEST_F(SpirvReader_AtomicsTest, ArrayNested) {
     auto* f = b.ComputeFunction("main");
 
     core::ir::Var* wg = nullptr;
@@ -144,7 +154,17 @@
     Run(Atomics);
 
     auto* expect = R"(
-UNIMPLEMENTED
+$B1: {  # root
+  %wg:ptr<workgroup, array<array<array<atomic<u32>, 1>, 2>, 3>, read_write> = var undef
+}
+
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+  $B2: {
+    %3:ptr<workgroup, atomic<u32>, read_write> = access %wg, 2i, 1i, 0i
+    %4:void = atomicStore %3, 1u
+    ret
+  }
+}
 )";
     ASSERT_EQ(expect, str());
 }
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 0d194d3..28cdce7 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -218,6 +218,23 @@
                 case spirv::BuiltinFn::kOuterProduct:
                     OuterProduct(builtin);
                     break;
+                case spirv::BuiltinFn::kAtomicLoad:
+                case spirv::BuiltinFn::kAtomicStore:
+                case spirv::BuiltinFn::kAtomicExchange:
+                case spirv::BuiltinFn::kAtomicCompareExchange:
+                case spirv::BuiltinFn::kAtomicIAdd:
+                case spirv::BuiltinFn::kAtomicISub:
+                case spirv::BuiltinFn::kAtomicSMax:
+                case spirv::BuiltinFn::kAtomicSMin:
+                case spirv::BuiltinFn::kAtomicUMax:
+                case spirv::BuiltinFn::kAtomicUMin:
+                case spirv::BuiltinFn::kAtomicAnd:
+                case spirv::BuiltinFn::kAtomicOr:
+                case spirv::BuiltinFn::kAtomicXor:
+                case spirv::BuiltinFn::kAtomicIIncrement:
+                case spirv::BuiltinFn::kAtomicIDecrement:
+                    // Ignore Atomics, they'll be handled by the `Atomics` transform.
+                    break;
                 default:
                     TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
             }
diff --git a/src/tint/lang/spirv/reader/lower/lower.cc b/src/tint/lang/spirv/reader/lower/lower.cc
index 86d7241..c5e4fb6 100644
--- a/src/tint/lang/spirv/reader/lower/lower.cc
+++ b/src/tint/lang/spirv/reader/lower/lower.cc
@@ -29,6 +29,7 @@
 
 #include "src/tint/lang/core/ir/transform/remove_terminator_args.h"
 #include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/spirv/reader/lower/atomics.h"
 #include "src/tint/lang/spirv/reader/lower/builtins.h"
 #include "src/tint/lang/spirv/reader/lower/shader_io.h"
 #include "src/tint/lang/spirv/reader/lower/vector_element_pointer.h"
@@ -47,6 +48,7 @@
     RUN_TRANSFORM(lower::VectorElementPointer, mod);
     RUN_TRANSFORM(lower::ShaderIO, mod);
     RUN_TRANSFORM(lower::Builtins, mod);
+    RUN_TRANSFORM(lower::Atomics, mod);
 
     // Remove the terminator args at this point. There are no logical short-circuiting operators in
     // SPIR-V that we will lose track of, all the terminators are for hoisted values. We don't do