[tint][IrToProgram] Reconstruct workgroupUniformLoad
Required to prevent uniformity analysis errors
Change-Id: I9bdc9cd6a83c9e443359bb23b0f77d2f2b40af58
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/154502
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/wgsl/writer/raise/raise.cc b/src/tint/lang/wgsl/writer/raise/raise.cc
index d0ab5da..53ffb6c 100644
--- a/src/tint/lang/wgsl/writer/raise/raise.cc
+++ b/src/tint/lang/wgsl/writer/raise/raise.cc
@@ -18,6 +18,8 @@
#include "src/tint/lang/core/builtin_fn.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
+#include "src/tint/lang/core/ir/load.h"
+#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/wgsl/builtin_fn.h"
#include "src/tint/lang/wgsl/ir/builtin_call.h"
@@ -151,17 +153,66 @@
}
}
+void ReplaceBuiltinFnCall(core::ir::Module& mod, core::ir::CoreBuiltinCall* call) {
+ Vector<core::ir::Value*, 8> args(call->Args());
+ auto* replacement = mod.instructions.Create<wgsl::ir::BuiltinCall>(
+ call->Result(), Convert(call->Func()), std::move(args));
+ call->ReplaceWith(replacement);
+ call->ClearResults();
+ call->Destroy();
+}
+
+void ReplaceWorkgroupBarrier(core::ir::Module& mod, core::ir::CoreBuiltinCall* call) {
+ // Pattern match:
+ // call workgroupBarrier
+ // %value = load &ptr
+ // call workgroupBarrier
+ // And replace with:
+ // %value = call workgroupUniformLoad %ptr
+
+ auto* load = As<core::ir::Load>(call->next);
+ if (!load || load->From()->Type()->As<core::type::Pointer>()->AddressSpace() !=
+ core::AddressSpace::kWorkgroup) {
+ // No match
+ ReplaceBuiltinFnCall(mod, call);
+ return;
+ }
+
+ auto* post_load = As<core::ir::CoreBuiltinCall>(load->next);
+ if (!post_load || post_load->Func() != core::BuiltinFn::kWorkgroupBarrier) {
+ // No match
+ ReplaceBuiltinFnCall(mod, call);
+ return;
+ }
+
+ // Remove both calls to workgroupBarrier
+ post_load->Destroy();
+ call->Destroy();
+
+ // Replace load with workgroupUniformLoad
+ auto* replacement = mod.instructions.Create<wgsl::ir::BuiltinCall>(
+ load->Result(), wgsl::BuiltinFn::kWorkgroupUniformLoad, Vector{load->From()});
+ load->ReplaceWith(replacement);
+ load->ClearResults();
+ load->Destroy();
+}
+
} // namespace
Result<SuccessType> Raise(core::ir::Module& mod) {
for (auto* inst : mod.instructions.Objects()) {
+ if (!inst->Alive()) {
+ continue;
+ }
if (auto* call = inst->As<core::ir::CoreBuiltinCall>()) {
- Vector<core::ir::Value*, 8> args(call->Args());
- auto* replacement = mod.instructions.Create<wgsl::ir::BuiltinCall>(
- call->Result(), Convert(call->Func()), std::move(args));
- call->ReplaceWith(replacement);
- call->ClearResults();
- call->Destroy();
+ switch (call->Func()) {
+ case core::BuiltinFn::kWorkgroupBarrier:
+ ReplaceWorkgroupBarrier(mod, call);
+ break;
+ default:
+ ReplaceBuiltinFnCall(mod, call);
+ break;
+ }
}
}
return Success;
diff --git a/src/tint/lang/wgsl/writer/raise/raise_test.cc b/src/tint/lang/wgsl/writer/raise/raise_test.cc
index a4f9b05..6928dd4 100644
--- a/src/tint/lang/wgsl/writer/raise/raise_test.cc
+++ b/src/tint/lang/wgsl/writer/raise/raise_test.cc
@@ -57,5 +57,100 @@
EXPECT_EQ(expect, str());
}
+TEST_F(WgslWriter_RaiseTest, WorkgroupBarrier) {
+ auto* W = b.Var<workgroup, i32, read_write>("W");
+ b.ir.root_block->Append(W);
+ auto* f = b.Function("f", ty.i32());
+ b.Append(f->Block(), [&] { //
+ b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+ auto* load = b.Load(W);
+ b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+ b.Return(f, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+ %b2 = block {
+ %3:void = workgroupBarrier
+ %4:i32 = load %W
+ %5:void = workgroupBarrier
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+ %b2 = block {
+ %3:i32 = wgsl.workgroupUniformLoad %W
+ ret %3
+ }
+}
+)";
+
+ Run(Raise);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(WgslWriter_RaiseTest, WorkgroupBarrier_NoMatch) {
+ auto* W = b.Var<workgroup, i32, read_write>("W");
+ b.ir.root_block->Append(W);
+ auto* f = b.Function("f", ty.i32());
+ b.Append(f->Block(), [&] { //
+ b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+ b.Store(W, 42_i); // Prevents pattern match
+ auto* load = b.Load(W);
+ b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+ b.Return(f, load);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+ %b2 = block {
+ %3:void = workgroupBarrier
+ store %W, 42i
+ %4:i32 = load %W
+ %5:void = workgroupBarrier
+ ret %4
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %W:ptr<workgroup, i32, read_write> = var
+}
+
+%f = func():i32 -> %b2 {
+ %b2 = block {
+ %3:void = wgsl.workgroupBarrier
+ store %W, 42i
+ %4:i32 = load %W
+ %5:void = wgsl.workgroupBarrier
+ ret %4
+ }
+}
+)";
+
+ Run(Raise);
+
+ EXPECT_EQ(expect, str());
+}
+
} // namespace
} // namespace tint::wgsl::writer::raise