writer/msl: Emit helper functions for atomicCompareExchangeWeak

By generating a helper function for these, we can keep the atomic expression pre-statement-free. This can help prevent for-loops from being transformed into while loops.

Change-Id: Id034ea5ea9be601661ddb78db973015d845c420f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57463
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 809ca10..5f82e30 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -55,6 +55,7 @@
 #include "src/sem/void_type.h"
 #include "src/transform/msl.h"
 #include "src/utils/defer.h"
+#include "src/utils/get_or_create.h"
 #include "src/utils/scoped_assignment.h"
 #include "src/writer/float_to_string.h"
 
@@ -91,6 +92,8 @@
   line();
   line() << "using namespace metal;";
 
+  auto helpers_insertion_point = current_buffer_->lines.size();
+
   for (auto* const type_decl : program_->AST().TypeDecls()) {
     if (!type_decl->Is<ast::Alias>()) {
       if (!EmitTypeDecl(TypeOf(type_decl))) {
@@ -137,6 +140,11 @@
     line();
   }
 
+  if (!helpers_.lines.empty()) {
+    current_buffer_->Insert("", helpers_insertion_point++, 0);
+    current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
+  }
+
   return true;
 }
 
@@ -454,7 +462,7 @@
 bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
                                    ast::CallExpression* expr,
                                    const sem::Intrinsic* intrinsic) {
-  auto call = [&](const char* name) {
+  auto call = [&](const std::string& name, bool append_memory_order_relaxed) {
     out << name;
     {
       ScopedParen sp(out);
@@ -467,84 +475,77 @@
           return false;
         }
       }
-      out << ", memory_order_relaxed";
+      if (append_memory_order_relaxed) {
+        out << ", memory_order_relaxed";
+      }
     }
     return true;
   };
 
   switch (intrinsic->Type()) {
     case sem::IntrinsicType::kAtomicLoad:
-      return call("atomic_load_explicit");
+      return call("atomic_load_explicit", true);
 
     case sem::IntrinsicType::kAtomicStore:
-      return call("atomic_store_explicit");
+      return call("atomic_store_explicit", true);
 
     case sem::IntrinsicType::kAtomicAdd:
-      return call("atomic_fetch_add_explicit");
+      return call("atomic_fetch_add_explicit", true);
 
     case sem::IntrinsicType::kAtomicMax:
-      return call("atomic_fetch_max_explicit");
+      return call("atomic_fetch_max_explicit", true);
 
     case sem::IntrinsicType::kAtomicMin:
-      return call("atomic_fetch_min_explicit");
+      return call("atomic_fetch_min_explicit", true);
 
     case sem::IntrinsicType::kAtomicAnd:
-      return call("atomic_fetch_and_explicit");
+      return call("atomic_fetch_and_explicit", true);
 
     case sem::IntrinsicType::kAtomicOr:
-      return call("atomic_fetch_or_explicit");
+      return call("atomic_fetch_or_explicit", true);
 
     case sem::IntrinsicType::kAtomicXor:
-      return call("atomic_fetch_xor_explicit");
+      return call("atomic_fetch_xor_explicit", true);
 
     case sem::IntrinsicType::kAtomicExchange:
-      return call("atomic_exchange_explicit");
+      return call("atomic_exchange_explicit", true);
 
     case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
-      auto* target = expr->params()[0];
-      auto* compare_value = expr->params()[1];
-      auto* value = expr->params()[2];
+      auto* ptr_ty = TypeOf(expr->params()[0])->UnwrapRef()->As<sem::Pointer>();
+      auto sc = ptr_ty->StorageClass();
 
-      auto prev_value = UniqueIdentifier("prev_value");
-      auto matched = UniqueIdentifier("matched");
+      auto func = utils::GetOrCreate(
+          atomicCompareExchangeWeak_, sc, [&]() -> std::string {
+            auto name = UniqueIdentifier("atomicCompareExchangeWeak");
+            auto& buf = helpers_;
 
-      {  // prev_value = <compare_value>;
-        auto pre = line();
-        if (!EmitType(pre, TypeOf(value), "")) {
-          return false;
-        }
-        pre << " " << prev_value << " = ";
-        if (!EmitExpression(pre, compare_value)) {
-          return false;
-        }
-        pre << ";";
-      }
+            line(&buf) << "template <typename A, typename T>";
+            {
+              auto f = line(&buf);
+              f << "vec<T, 2> " << name << "(";
+              if (!EmitStorageClass(f, sc)) {
+                return "";
+              }
+              f << " A* atomic, T compare, T value) {";
+            }
 
-      {  // bool matched = atomic_compare_exchange_weak_explicit(
-         //   target, &got, <value>, memory_order_relaxed, memory_order_relaxed)
-        auto pre = line();
-        pre << "bool " << matched << " = atomic_compare_exchange_weak_explicit";
-        {
-          ScopedParen sp(pre);
-          if (!EmitExpression(pre, target)) {
-            return false;
-          }
-          pre << ", &" << prev_value << ", ";
-          if (!EmitExpression(pre, value)) {
-            return false;
-          }
-          pre << ", memory_order_relaxed, memory_order_relaxed";
-        }
-        pre << ";";
-      }
+            buf.IncrementIndent();
+            TINT_DEFER({
+              buf.DecrementIndent();
+              line(&buf) << "}";
+              line(&buf);
+            });
 
-      {  // [u]int2(got, matched)
-        if (!EmitType(out, TypeOf(expr), "")) {
-          return false;
-        }
-        out << "(" << prev_value << ", " << matched << ")";
-      }
-      return true;
+            line(&buf) << "T prev_value = compare;";
+            line(&buf) << "bool matched = "
+                          "atomic_compare_exchange_weak_explicit(atomic, "
+                          "&prev_value, value, memory_order_relaxed, "
+                          "memory_order_relaxed);";
+            line(&buf) << "return {prev_value, matched};";
+            return name;
+          });
+
+      return call(func, false);
     }
 
     default:
@@ -1867,24 +1868,10 @@
   }
 
   if (auto* ptr = type->As<sem::Pointer>()) {
-    switch (ptr->StorageClass()) {
-      case ast::StorageClass::kFunction:
-      case ast::StorageClass::kPrivate:
-      case ast::StorageClass::kUniformConstant:
-        out << "thread ";
-        break;
-      case ast::StorageClass::kWorkgroup:
-        out << "threadgroup ";
-        break;
-      case ast::StorageClass::kStorage:
-        out << "device ";
-        break;
-      case ast::StorageClass::kUniform:
-        out << "constant ";
-        break;
-      default:
-        TINT_ICE(Writer, diagnostics_) << "unhandled storage class for pointer";
+    if (!EmitStorageClass(out, ptr->StorageClass())) {
+      return false;
     }
+    out << " ";
     if (ptr->StoreType()->Is<sem::Array>()) {
       std::string inner = "(*" + name + ")";
       if (!EmitType(out, ptr->StoreType(), inner)) {
@@ -2004,6 +1991,29 @@
   return false;
 }
 
+bool GeneratorImpl::EmitStorageClass(std::ostream& out, ast::StorageClass sc) {
+  switch (sc) {
+    case ast::StorageClass::kFunction:
+    case ast::StorageClass::kPrivate:
+    case ast::StorageClass::kUniformConstant:
+      out << "thread";
+      return true;
+    case ast::StorageClass::kWorkgroup:
+      out << "threadgroup";
+      return true;
+    case ast::StorageClass::kStorage:
+      out << "device";
+      return true;
+    case ast::StorageClass::kUniform:
+      out << "constant";
+      return true;
+    default:
+      break;
+  }
+  TINT_ICE(Writer, diagnostics_) << "unhandled storage class: " << sc;
+  return false;
+}
+
 bool GeneratorImpl::EmitPackedType(std::ostream& out,
                                    const sem::Type* type,
                                    const std::string& name) {
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 1495178..49333e3 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -16,6 +16,7 @@
 #define SRC_WRITER_MSL_GENERATOR_IMPL_H_
 
 #include <string>
+#include <unordered_map>
 
 #include "src/ast/array_accessor_expression.h"
 #include "src/ast/assignment_statement.h"
@@ -218,6 +219,11 @@
   bool EmitType(std::ostream& out,
                 const sem::Type* type,
                 const std::string& name);
+  /// Handles generating a storage class
+  /// @param out the output of the type stream
+  /// @param sc the storage class to generate
+  /// @returns true if the storage class is emitted
+  bool EmitStorageClass(std::ostream& out, ast::StorageClass sc);
   /// Handles generating an MSL-packed storage type.
   /// If the type does not have a packed form, the standard non-packed form is
   /// emitted.
@@ -282,11 +288,20 @@
     uint32_t align;
   };
 
+  TextBuffer helpers_;  // Helper functions emitted at the top of the output
+
   /// @returns the MSL packed type size and alignment in bytes for the given
   /// type.
   SizeAndAlign MslPackedTypeSizeAndAlign(const sem::Type* ty);
 
+  using StorageClassToString =
+      std::unordered_map<ast::StorageClass, std::string>;
+
   std::function<bool()> emit_continuing_;
+
+  /// Name of atomicCompareExchangeWeak() helper for the given pointer storage
+  /// class.
+  StorageClassToString atomicCompareExchangeWeak_;
 };
 
 }  // namespace msl
diff --git a/src/writer/msl/generator_impl_loop_test.cc b/src/writer/msl/generator_impl_loop_test.cc
index a6ae411..b01d55b 100644
--- a/src/writer/msl/generator_impl_loop_test.cc
+++ b/src/writer/msl/generator_impl_loop_test.cc
@@ -185,9 +185,8 @@
   //   return;
   // }
   Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
-  auto* multi_stmt = Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2);
-  auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr,
-                Block(Return()));
+  auto* multi_stmt = Block(Ignore(1), Ignore(2));
+  auto* f = For(multi_stmt, nullptr, nullptr, Block(Return()));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -196,9 +195,10 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
-    int prev_value = 1;
-    bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
-    int2 b = int2(prev_value, matched);
+    {
+      (void) 1;
+      (void) 2;
+    }
     for(; ; ) {
       return;
     }
@@ -225,35 +225,6 @@
 )");
 }
 
-TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCond) {
-  // var<workgroup> a : atomic<i32>;
-  // for(; atomicCompareExchangeWeak(&a, 1, 2).x == 0; ) {
-  //   return;
-  // }
-
-  Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
-  auto* multi_stmt = create<ast::BinaryExpression>(
-      ast::BinaryOp::kEqual,
-      MemberAccessor(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2),
-                     "x"),
-      Expr(0));
-  auto* f = For(nullptr, multi_stmt, nullptr, Block(Return()));
-  WrapInFunction(f);
-
-  GeneratorImpl& gen = Build();
-
-  gen.increment_indent();
-
-  ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
-  EXPECT_EQ(gen.result(), R"(  while (true) {
-    int prev_value = 1;
-    bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
-    if (!((int2(prev_value, matched).x == 0))) { break; }
-    return;
-  }
-)");
-}
-
 TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
   // for(; ; i = i + 1) {
   //   return;
@@ -276,13 +247,12 @@
 
 TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
   // var<workgroup> a : atomic<i32>;
-  // for(; ; ignore(atomicCompareExchangeWeak(&a, 1, 2))) {
+  // for(; ; { ignore(1); ignore(2); }) {
   //   return;
   // }
 
   Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
-  auto* multi_stmt =
-      Ignore(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2));
+  auto* multi_stmt = Block(Ignore(1), Ignore(2));
   auto* f = For(nullptr, nullptr, multi_stmt, Block(Return()));
   WrapInFunction(f);
 
@@ -293,9 +263,10 @@
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  while (true) {
     return;
-    int prev_value = 1;
-    bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
-    (void) int2(prev_value, matched);
+    {
+      (void) 1;
+      (void) 2;
+    }
   }
 )");
 }
@@ -322,22 +293,13 @@
 
 TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
   // var<workgroup> a : atomic<i32>;
-  // for(var b = atomicCompareExchangeWeak(&a, 1, 2);
-  //     atomicCompareExchangeWeak(&a, 1, 2).x == 0;
-  //     ignore(atomicCompareExchangeWeak(&a, 1, 2))) {
+  // for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) {
   //   return;
   // }
   Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
-  auto* multi_stmt_a = Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2);
-  auto* multi_stmt_b = create<ast::BinaryExpression>(
-      ast::BinaryOp::kEqual,
-      MemberAccessor(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2),
-                     "x"),
-      Expr(0));
-  auto* multi_stmt_c =
-      Ignore(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2));
-  auto* f = For(Decl(Var("b", nullptr, multi_stmt_a)), multi_stmt_b,
-                multi_stmt_c, Block(Return()));
+  auto* multi_stmt_a = Block(Ignore(1), Ignore(2));
+  auto* multi_stmt_b = Block(Ignore(3), Ignore(4));
+  auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return()));
   WrapInFunction(f);
 
   GeneratorImpl& gen = Build();
@@ -346,17 +308,17 @@
 
   ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
   EXPECT_EQ(gen.result(), R"(  {
-    int prev_value = 1;
-    bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
-    int2 b = int2(prev_value, matched);
+    {
+      (void) 1;
+      (void) 2;
+    }
     while (true) {
-      int prev_value_1 = 1;
-      bool matched_1 = atomic_compare_exchange_weak_explicit(&(a), &prev_value_1, 2, memory_order_relaxed, memory_order_relaxed);
-      if (!((int2(prev_value_1, matched_1).x == 0))) { break; }
+      if (!(true)) { break; }
       return;
-      int prev_value_2 = 1;
-      bool matched_2 = atomic_compare_exchange_weak_explicit(&(a), &prev_value_2, 2, memory_order_relaxed, memory_order_relaxed);
-      (void) int2(prev_value_2, matched_2);
+      {
+        (void) 3;
+        (void) 4;
+      }
     }
   }
 )");
diff --git a/test/intrinsics/gen/atomicCompareExchangeWeak/12871c.wgsl.expected.msl b/test/intrinsics/gen/atomicCompareExchangeWeak/12871c.wgsl.expected.msl
index 56d8090..37ea780 100644
--- a/test/intrinsics/gen/atomicCompareExchangeWeak/12871c.wgsl.expected.msl
+++ b/test/intrinsics/gen/atomicCompareExchangeWeak/12871c.wgsl.expected.msl
@@ -1,14 +1,20 @@
 #include <metal_stdlib>
 
 using namespace metal;
+
+template <typename A, typename T>
+vec<T, 2> atomicCompareExchangeWeak_1(device A* atomic, T compare, T value) {
+  T prev_value = compare;
+  bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
+  return {prev_value, matched};
+}
+
 struct SB_RW {
   /* 0x0000 */ atomic_int arg_0;
 };
 
 void atomicCompareExchangeWeak_12871c(device SB_RW& sb_rw) {
-  int prev_value = 1;
-  bool matched = atomic_compare_exchange_weak_explicit(&(sb_rw.arg_0), &prev_value, 1, memory_order_relaxed, memory_order_relaxed);
-  int2 res = int2(prev_value, matched);
+  int2 res = atomicCompareExchangeWeak_1(&(sb_rw.arg_0), 1, 1);
 }
 
 fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) {
diff --git a/test/intrinsics/gen/atomicCompareExchangeWeak/6673da.wgsl.expected.msl b/test/intrinsics/gen/atomicCompareExchangeWeak/6673da.wgsl.expected.msl
index c4999ee..30fe03f 100644
--- a/test/intrinsics/gen/atomicCompareExchangeWeak/6673da.wgsl.expected.msl
+++ b/test/intrinsics/gen/atomicCompareExchangeWeak/6673da.wgsl.expected.msl
@@ -1,14 +1,20 @@
 #include <metal_stdlib>
 
 using namespace metal;
+
+template <typename A, typename T>
+vec<T, 2> atomicCompareExchangeWeak_1(device A* atomic, T compare, T value) {
+  T prev_value = compare;
+  bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
+  return {prev_value, matched};
+}
+
 struct SB_RW {
   /* 0x0000 */ atomic_uint arg_0;
 };
 
 void atomicCompareExchangeWeak_6673da(device SB_RW& sb_rw) {
-  uint prev_value = 1u;
-  bool matched = atomic_compare_exchange_weak_explicit(&(sb_rw.arg_0), &prev_value, 1u, memory_order_relaxed, memory_order_relaxed);
-  uint2 res = uint2(prev_value, matched);
+  uint2 res = atomicCompareExchangeWeak_1(&(sb_rw.arg_0), 1u, 1u);
 }
 
 fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) {
diff --git a/test/intrinsics/gen/atomicCompareExchangeWeak/89ea3b.wgsl.expected.msl b/test/intrinsics/gen/atomicCompareExchangeWeak/89ea3b.wgsl.expected.msl
index 036e9bf..1350db4 100644
--- a/test/intrinsics/gen/atomicCompareExchangeWeak/89ea3b.wgsl.expected.msl
+++ b/test/intrinsics/gen/atomicCompareExchangeWeak/89ea3b.wgsl.expected.msl
@@ -1,10 +1,16 @@
 #include <metal_stdlib>
 
 using namespace metal;
+
+template <typename A, typename T>
+vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value) {
+  T prev_value = compare;
+  bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
+  return {prev_value, matched};
+}
+
 void atomicCompareExchangeWeak_89ea3b(threadgroup atomic_int* const tint_symbol_1) {
-  int prev_value = 1;
-  bool matched = atomic_compare_exchange_weak_explicit(&(*(tint_symbol_1)), &prev_value, 1, memory_order_relaxed, memory_order_relaxed);
-  int2 res = int2(prev_value, matched);
+  int2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1, 1);
 }
 
 kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {
diff --git a/test/intrinsics/gen/atomicCompareExchangeWeak/b2ab2c.wgsl.expected.msl b/test/intrinsics/gen/atomicCompareExchangeWeak/b2ab2c.wgsl.expected.msl
index 45d921c..99c9ab1 100644
--- a/test/intrinsics/gen/atomicCompareExchangeWeak/b2ab2c.wgsl.expected.msl
+++ b/test/intrinsics/gen/atomicCompareExchangeWeak/b2ab2c.wgsl.expected.msl
@@ -1,10 +1,16 @@
 #include <metal_stdlib>
 
 using namespace metal;
+
+template <typename A, typename T>
+vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value) {
+  T prev_value = compare;
+  bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
+  return {prev_value, matched};
+}
+
 void atomicCompareExchangeWeak_b2ab2c(threadgroup atomic_uint* const tint_symbol_1) {
-  uint prev_value = 1u;
-  bool matched = atomic_compare_exchange_weak_explicit(&(*(tint_symbol_1)), &prev_value, 1u, memory_order_relaxed, memory_order_relaxed);
-  uint2 res = uint2(prev_value, matched);
+  uint2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1u, 1u);
 }
 
 kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {