tint: spirv reader: detect and replace loads and stores of atomic variables with atomicLoad/Store

Bug: tint:1441
Change-Id: Iee89cb87ca063d8a98ff8ad789ba14dee65c036a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95140
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/transform/spirv_atomic.cc b/src/tint/transform/spirv_atomic.cc
index d820a87..dc5b35c 100644
--- a/src/tint/transform/spirv_atomic.cc
+++ b/src/tint/transform/spirv_atomic.cc
@@ -135,6 +135,10 @@
             });
         }
 
+        // Replace assignments and decls from atomic variables with atomicLoads, and assignments to
+        // atomic variables with atomicStores.
+        ReplaceLoadsAndStores();
+
         ctx.Clone();
     }
 
@@ -200,6 +204,70 @@
                 return nullptr;
             });
     }
+
+    void ReplaceLoadsAndStores() {
+        // Returns true if 'e' is a reference to an atomic variable or struct member
+        auto is_ref_to_atomic_var = [&](const sem::Expression* e) {
+            if (tint::Is<sem::Reference>(e->Type()) && e->SourceVariable() &&
+                (atomic_variables.count(e->SourceVariable()) != 0)) {
+                // If it's a struct member, make sure it's one we marked as atomic
+                if (auto* ma = e->As<sem::StructMemberAccess>()) {
+                    auto it = forked_structs.find(ma->Member()->Struct()->Declaration());
+                    if (it != forked_structs.end()) {
+                        auto& forked = it->second;
+                        return forked.atomic_members.count(ma->Member()->Index()) != 0;
+                    }
+                }
+                return true;
+            }
+            return false;
+        };
+
+        // Look for loads and stores via assignments and decls of atomic variables we've collected
+        // so far, and replace them with atomicLoad and atomicStore.
+        for (auto* atomic_var : atomic_variables) {
+            for (auto* vu : atomic_var->Users()) {
+                Switch(
+                    vu->Stmt()->Declaration(),
+                    [&](const ast::AssignmentStatement* assign) {
+                        auto* sem_lhs = ctx.src->Sem().Get(assign->lhs);
+                        if (is_ref_to_atomic_var(sem_lhs)) {
+                            ctx.Replace(assign, [=] {
+                                auto* lhs = ctx.CloneWithoutTransform(assign->lhs);
+                                auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
+                                auto* call = b.Call(sem::str(sem::BuiltinType::kAtomicStore),
+                                                    b.AddressOf(lhs), rhs);
+                                return b.CallStmt(call);
+                            });
+                            return;
+                        }
+
+                        auto sem_rhs = ctx.src->Sem().Get(assign->rhs);
+                        if (is_ref_to_atomic_var(sem_rhs)) {
+                            ctx.Replace(assign->rhs, [=] {
+                                auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
+                                return b.Call(sem::str(sem::BuiltinType::kAtomicLoad),
+                                              b.AddressOf(rhs));
+                            });
+                            return;
+                        }
+                    },
+                    [&](const ast::VariableDeclStatement* decl) {
+                        auto* var = decl->variable;
+                        if (auto* sem_ctor = ctx.src->Sem().Get(var->constructor)) {
+                            if (is_ref_to_atomic_var(sem_ctor)) {
+                                ctx.Replace(var->constructor, [=] {
+                                    auto* rhs = ctx.CloneWithoutTransform(var->constructor);
+                                    return b.Call(sem::str(sem::BuiltinType::kAtomicLoad),
+                                                  b.AddressOf(rhs));
+                                });
+                                return;
+                            }
+                        }
+                    });
+            }
+        }
+    }
 };
 
 SpirvAtomic::SpirvAtomic() = default;
diff --git a/src/tint/transform/spirv_atomic_test.cc b/src/tint/transform/spirv_atomic_test.cc
index 804193e..7f07d06 100644
--- a/src/tint/transform/spirv_atomic_test.cc
+++ b/src/tint/transform/spirv_atomic_test.cc
@@ -1018,5 +1018,366 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_Scaler) {
+    auto* src = R"(
+var<workgroup> wg : u32;
+
+fn f() {
+  stub_atomicAdd_u32(wg, 1u);
+
+  wg = 0u;
+  let a = wg;
+  var b : u32;
+  b = wg;
+}
+)";
+
+    auto* expect = R"(
+var<workgroup> wg : atomic<u32>;
+
+fn f() {
+  atomicAdd(&(wg), 1u);
+  atomicStore(&(wg), 0u);
+  let a = atomicLoad(&(wg));
+  var b : u32;
+  b = atomicLoad(&(wg));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_Struct) {
+    auto* src = R"(
+struct S {
+  a : u32,
+}
+
+var<workgroup> wg : S;
+
+fn f() {
+  stub_atomicAdd_u32(wg.a, 1u);
+
+  wg.a = 0u;
+  let a = wg.a;
+  var b : u32;
+  b = wg.a;
+}
+)";
+
+    auto* expect = R"(
+struct S_atomic {
+  a : atomic<u32>,
+}
+
+struct S {
+  a : u32,
+}
+
+var<workgroup> wg : S_atomic;
+
+fn f() {
+  atomicAdd(&(wg.a), 1u);
+  atomicStore(&(wg.a), 0u);
+  let a = atomicLoad(&(wg.a));
+  var b : u32;
+  b = atomicLoad(&(wg.a));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_NestedStruct) {
+    auto* src = R"(
+struct S0 {
+  a : u32,
+}
+
+struct S1 {
+  s0 : S0
+}
+
+var<workgroup> wg : S1;
+
+fn f() {
+  stub_atomicAdd_u32(wg.s0.a, 1u);
+
+  wg.s0.a = 0u;
+  let a = wg.s0.a;
+  var b : u32;
+  b = wg.s0.a;
+}
+)";
+
+    auto* expect = R"(
+struct S0_atomic {
+  a : atomic<u32>,
+}
+
+struct S0 {
+  a : u32,
+}
+
+struct S1_atomic {
+  s0 : S0_atomic,
+}
+
+struct S1 {
+  s0 : S0,
+}
+
+var<workgroup> wg : S1_atomic;
+
+fn f() {
+  atomicAdd(&(wg.s0.a), 1u);
+  atomicStore(&(wg.s0.a), 0u);
+  let a = atomicLoad(&(wg.s0.a));
+  var b : u32;
+  b = atomicLoad(&(wg.s0.a));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_StructMultipleAtomics) {
+    auto* src = R"(
+struct S {
+  a : u32,
+  b : u32,
+  c : u32,
+}
+
+var<workgroup> wg : S;
+
+fn f() {
+  stub_atomicAdd_u32(wg.a, 1u);
+  stub_atomicAdd_u32(wg.b, 1u);
+
+  wg.a = 0u;
+  let a = wg.a;
+  var b : u32;
+  b = wg.a;
+
+  wg.b = 0u;
+  let c = wg.b;
+  var d : u32;
+  d = wg.b;
+
+  wg.c = 0u;
+  let e = wg.c;
+  var f : u32;
+  f = wg.c;
+}
+)";
+
+    auto* expect = R"(
+struct S_atomic {
+  a : atomic<u32>,
+  b : atomic<u32>,
+  c : u32,
+}
+
+struct S {
+  a : u32,
+  b : u32,
+  c : u32,
+}
+
+var<workgroup> wg : S_atomic;
+
+fn f() {
+  atomicAdd(&(wg.a), 1u);
+  atomicAdd(&(wg.b), 1u);
+  atomicStore(&(wg.a), 0u);
+  let a = atomicLoad(&(wg.a));
+  var b : u32;
+  b = atomicLoad(&(wg.a));
+  atomicStore(&(wg.b), 0u);
+  let c = atomicLoad(&(wg.b));
+  var d : u32;
+  d = atomicLoad(&(wg.b));
+  wg.c = 0u;
+  let e = wg.c;
+  var f : u32;
+  f = wg.c;
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_ArrayOfScalar) {
+    auto* src = R"(
+var<workgroup> wg : array<u32, 4>;
+
+fn f() {
+  stub_atomicAdd_u32(wg[1], 1u);
+
+  wg[1] = 0u;
+  let a = wg[1];
+  var b : u32;
+  b = wg[1];
+}
+)";
+
+    auto* expect = R"(
+var<workgroup> wg : array<atomic<u32>, 4u>;
+
+fn f() {
+  atomicAdd(&(wg[1]), 1u);
+  atomicStore(&(wg[1]), 0u);
+  let a = atomicLoad(&(wg[1]));
+  var b : u32;
+  b = atomicLoad(&(wg[1]));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_ArrayOfStruct) {
+    auto* src = R"(
+struct S {
+  a : u32,
+}
+
+var<workgroup> wg : array<S, 4>;
+
+fn f() {
+  stub_atomicAdd_u32(wg[1].a, 1u);
+
+  wg[1].a = 0u;
+  let a = wg[1].a;
+  var b : u32;
+  b = wg[1].a;
+}
+)";
+
+    auto* expect = R"(
+struct S_atomic {
+  a : atomic<u32>,
+}
+
+struct S {
+  a : u32,
+}
+
+var<workgroup> wg : array<S_atomic, 4u>;
+
+fn f() {
+  atomicAdd(&(wg[1].a), 1u);
+  atomicStore(&(wg[1].a), 0u);
+  let a = atomicLoad(&(wg[1].a));
+  var b : u32;
+  b = atomicLoad(&(wg[1].a));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_StructOfArray) {
+    auto* src = R"(
+struct S {
+  a : array<u32>,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+  stub_atomicAdd_u32(s.a[4], 1u);
+
+  s.a[4] = 0u;
+  let a = s.a[4];
+  var b : u32;
+  b = s.a[4];
+}
+)";
+
+    auto* expect = R"(
+struct S_atomic {
+  a : array<atomic<u32>>,
+}
+
+struct S {
+  a : array<u32>,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+  atomicAdd(&(s.a[4]), 1u);
+  atomicStore(&(s.a[4]), 0u);
+  let a = atomicLoad(&(s.a[4]));
+  var b : u32;
+  b = atomicLoad(&(s.a[4]));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvAtomicTest, ReplaceAssignsAndDecls_ViaPtrLet) {
+    auto* src = R"(
+struct S {
+  i : u32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S;
+
+fn f() {
+  let p0 = &(s);
+  let p1 : ptr<storage, u32, read_write> = &((*(p0)).i);
+  stub_atomicAdd_u32(*p1, 1u);
+
+  *p1 = 0u;
+  let a = *p1;
+  var b : u32;
+  b = *p1;
+}
+)";
+
+    auto* expect = R"(
+struct S_atomic {
+  i : atomic<u32>,
+}
+
+struct S {
+  i : u32,
+}
+
+@group(0) @binding(1) var<storage, read_write> s : S_atomic;
+
+fn f() {
+  let p0 = &(s);
+  let p1 : ptr<storage, atomic<u32>, read_write> = &((*(p0)).i);
+  atomicAdd(&(*(p1)), 1u);
+  atomicStore(&(*(p1)), 0u);
+  let a = atomicLoad(&(*(p1)));
+  var b : u32;
+  b = atomicLoad(&(*(p1)));
+}
+)";
+
+    auto got = Run(src);
+
+    EXPECT_EQ(expect, str(got));
+}
 }  // namespace
 }  // namespace tint::transform