[tint][IrToProgram] Reconstruct workgroupUniformLoad

Required to prevent uniformity analysis errors

Change-Id: I9bdc9cd6a83c9e443359bb23b0f77d2f2b40af58
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/154502
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/wgsl/writer/raise/raise.cc b/src/tint/lang/wgsl/writer/raise/raise.cc
index d0ab5da..53ffb6c 100644
--- a/src/tint/lang/wgsl/writer/raise/raise.cc
+++ b/src/tint/lang/wgsl/writer/raise/raise.cc
@@ -18,6 +18,8 @@
 
 #include "src/tint/lang/core/builtin_fn.h"
 #include "src/tint/lang/core/ir/core_builtin_call.h"
+#include "src/tint/lang/core/ir/load.h"
+#include "src/tint/lang/core/type/pointer.h"
 #include "src/tint/lang/wgsl/builtin_fn.h"
 #include "src/tint/lang/wgsl/ir/builtin_call.h"
 
@@ -151,17 +153,66 @@
     }
 }
 
+void ReplaceBuiltinFnCall(core::ir::Module& mod, core::ir::CoreBuiltinCall* call) {
+    Vector<core::ir::Value*, 8> args(call->Args());
+    auto* replacement = mod.instructions.Create<wgsl::ir::BuiltinCall>(
+        call->Result(), Convert(call->Func()), std::move(args));
+    call->ReplaceWith(replacement);
+    call->ClearResults();
+    call->Destroy();
+}
+
+void ReplaceWorkgroupBarrier(core::ir::Module& mod, core::ir::CoreBuiltinCall* call) {
+    // Pattern match:
+    //    call workgroupBarrier
+    //    %value = load &ptr
+    //    call workgroupBarrier
+    // And replace with:
+    //    %value = call workgroupUniformLoad %ptr
+
+    auto* load = As<core::ir::Load>(call->next);
+    if (!load || load->From()->Type()->As<core::type::Pointer>()->AddressSpace() !=
+                     core::AddressSpace::kWorkgroup) {
+        // No match
+        ReplaceBuiltinFnCall(mod, call);
+        return;
+    }
+
+    auto* post_load = As<core::ir::CoreBuiltinCall>(load->next);
+    if (!post_load || post_load->Func() != core::BuiltinFn::kWorkgroupBarrier) {
+        // No match
+        ReplaceBuiltinFnCall(mod, call);
+        return;
+    }
+
+    // Remove both calls to workgroupBarrier
+    post_load->Destroy();
+    call->Destroy();
+
+    // Replace load with workgroupUniformLoad
+    auto* replacement = mod.instructions.Create<wgsl::ir::BuiltinCall>(
+        load->Result(), wgsl::BuiltinFn::kWorkgroupUniformLoad, Vector{load->From()});
+    load->ReplaceWith(replacement);
+    load->ClearResults();
+    load->Destroy();
+}
+
 }  // namespace
 
 Result<SuccessType> Raise(core::ir::Module& mod) {
     for (auto* inst : mod.instructions.Objects()) {
+        if (!inst->Alive()) {
+            continue;
+        }
         if (auto* call = inst->As<core::ir::CoreBuiltinCall>()) {
-            Vector<core::ir::Value*, 8> args(call->Args());
-            auto* replacement = mod.instructions.Create<wgsl::ir::BuiltinCall>(
-                call->Result(), Convert(call->Func()), std::move(args));
-            call->ReplaceWith(replacement);
-            call->ClearResults();
-            call->Destroy();
+            switch (call->Func()) {
+                case core::BuiltinFn::kWorkgroupBarrier:
+                    ReplaceWorkgroupBarrier(mod, call);
+                    break;
+                default:
+                    ReplaceBuiltinFnCall(mod, call);
+                    break;
+            }
         }
     }
     return Success;
diff --git a/src/tint/lang/wgsl/writer/raise/raise_test.cc b/src/tint/lang/wgsl/writer/raise/raise_test.cc
index a4f9b05..6928dd4 100644
--- a/src/tint/lang/wgsl/writer/raise/raise_test.cc
+++ b/src/tint/lang/wgsl/writer/raise/raise_test.cc
@@ -57,5 +57,100 @@
     EXPECT_EQ(expect, str());
 }
 
+TEST_F(WgslWriter_RaiseTest, WorkgroupBarrier) {
+    auto* W = b.Var<workgroup, i32, read_write>("W");
+    b.ir.root_block->Append(W);
+    auto* f = b.Function("f", ty.i32());
+    b.Append(f->Block(), [&] {  //
+        b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+        auto* load = b.Load(W);
+        b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+        b.Return(f, load);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+  %b2 = block {
+    %3:void = workgroupBarrier
+    %4:i32 = load %W
+    %5:void = workgroupBarrier
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+  %b2 = block {
+    %3:i32 = wgsl.workgroupUniformLoad %W
+    ret %3
+  }
+}
+)";
+
+    Run(Raise);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(WgslWriter_RaiseTest, WorkgroupBarrier_NoMatch) {
+    auto* W = b.Var<workgroup, i32, read_write>("W");
+    b.ir.root_block->Append(W);
+    auto* f = b.Function("f", ty.i32());
+    b.Append(f->Block(), [&] {  //
+        b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+        b.Store(W, 42_i);  // Prevents pattern match
+        auto* load = b.Load(W);
+        b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+        b.Return(f, load);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+  %b2 = block {
+    %3:void = workgroupBarrier
+    store %W, 42i
+    %4:i32 = load %W
+    %5:void = workgroupBarrier
+    ret %4
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+  %b2 = block {
+    %3:void = wgsl.workgroupBarrier
+    store %W, 42i
+    %4:i32 = load %W
+    %5:void = wgsl.workgroupBarrier
+    ret %4
+  }
+}
+)";
+
+    Run(Raise);
+
+    EXPECT_EQ(expect, str());
+}
+
 }  // namespace
 }  // namespace tint::wgsl::writer::raise