[spirv-reader][ir] Convert spirv atomic_store builtin.
Add initial atomic transform support and convert the
`spirv.atomic_store` builtin to a core `atomicStore` WGSL instruction.
Bug: 391487430
Change-Id: I03658d0d3f2e82e8717d9a18e4e85805da68ebeb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/233054
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/lower/atomics.cc b/src/tint/lang/spirv/reader/lower/atomics.cc
index c8dbe9c..ac91882 100644
--- a/src/tint/lang/spirv/reader/lower/atomics.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics.cc
@@ -30,6 +30,9 @@
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/spirv/ir/builtin_call.h"
+#include "src/tint/utils/containers/hashset.h"
+#include "src/tint/utils/containers/vector.h"
namespace tint::spirv::reader::lower {
namespace {
@@ -48,8 +51,103 @@
/// The type manager.
core::type::Manager& ty{ir.Types()};
+ /// The `ir::Value`s to be converted to atomics
+ Vector<core::ir::Value*, 8> values_to_convert_{};
+
+ /// The `ir::Value`s which have been converted
+ Hashset<core::ir::Value*, 8> converted_{};
+
/// Process the module.
- void Process() {}
+ void Process() {
+ Vector<spirv::ir::BuiltinCall*, 4> builtin_worklist;
+ for (auto* inst : ir.Instructions()) {
+ if (auto* builtin = inst->As<spirv::ir::BuiltinCall>()) {
+ builtin_worklist.Push(builtin);
+ }
+ }
+
+ for (auto* builtin : builtin_worklist) {
+ switch (builtin->Func()) {
+ case spirv::BuiltinFn::kAtomicLoad:
+ break;
+ case spirv::BuiltinFn::kAtomicStore:
+ AtomicStore(builtin);
+ break;
+ case spirv::BuiltinFn::kAtomicExchange:
+ case spirv::BuiltinFn::kAtomicCompareExchange:
+ case spirv::BuiltinFn::kAtomicIAdd:
+ case spirv::BuiltinFn::kAtomicISub:
+ case spirv::BuiltinFn::kAtomicSMax:
+ case spirv::BuiltinFn::kAtomicSMin:
+ case spirv::BuiltinFn::kAtomicUMax:
+ case spirv::BuiltinFn::kAtomicUMin:
+ case spirv::BuiltinFn::kAtomicAnd:
+ case spirv::BuiltinFn::kAtomicOr:
+ case spirv::BuiltinFn::kAtomicXor:
+ case spirv::BuiltinFn::kAtomicIIncrement:
+ case spirv::BuiltinFn::kAtomicIDecrement:
+ break;
+ default:
+ TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
+ }
+ }
+
+ while (!values_to_convert_.IsEmpty()) {
+ auto* val = values_to_convert_.Pop();
+
+ if (converted_.Add(val)) {
+ ConvertAtomicValue(val);
+ }
+ }
+ }
+
+ void AtomicStore(spirv::ir::BuiltinCall* call) {
+ auto args = call->Args();
+
+ b.InsertBefore(call, [&] {
+ auto* var = args[0];
+ values_to_convert_.Push(var);
+
+ auto* val = args[3];
+ b.CallWithResult(call->DetachResult(), core::BuiltinFn::kAtomicStore, var, val);
+ });
+ call->Destroy();
+ }
+
+ void ConvertAtomicValue(core::ir::Value* val) {
+ auto* res = val->As<core::ir::InstructionResult>();
+ TINT_ASSERT(res);
+
+ auto* atomic_ty = AtomicTypeFor(res->Type());
+ res->SetType(atomic_ty);
+
+ tint::Switch( //
+ res->Instruction(), //
+ [&](core::ir::Access* a) { values_to_convert_.Push(a->Object()); }, //
+ [&](core::ir::Var*) {}, //
+ TINT_ICE_ON_NO_MATCH);
+ }
+
+ const core::type::Type* AtomicTypeFor(const core::type::Type* orig_ty) {
+ return tint::Switch(
+ orig_ty, //
+ [&](const core::type::I32*) { return ty.atomic(orig_ty); },
+ [&](const core::type::U32*) { return ty.atomic(orig_ty); },
+ // [&](const core::type::Struct* str) { return ty(Fork(str).name); },
+ [&](const core::type::Array* arr) {
+ if (arr->Count()->Is<core::type::RuntimeArrayCount>()) {
+ return ty.runtime_array(AtomicTypeFor(arr->ElemType()));
+ }
+ auto count = arr->ConstantCount();
+ TINT_ASSERT(count);
+
+ return ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
+ },
+ [&](const core::type::Pointer* ptr) {
+ return ty.ptr(ptr->AddressSpace(), AtomicTypeFor(ptr->StoreType()), ptr->Access());
+ },
+ TINT_ICE_ON_NO_MATCH);
+ }
};
} // namespace
diff --git a/src/tint/lang/spirv/reader/lower/atomics_test.cc b/src/tint/lang/spirv/reader/lower/atomics_test.cc
index 5450a7e..d3bef8e 100644
--- a/src/tint/lang/spirv/reader/lower/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics_test.cc
@@ -38,7 +38,7 @@
using SpirvReader_AtomicsTest = core::ir::transform::TransformTest;
-TEST_F(SpirvReader_AtomicsTest, DISABLED_ArrayStore) {
+TEST_F(SpirvReader_AtomicsTest, ArrayStore) {
auto* f = b.ComputeFunction("main");
core::ir::Var* wg = nullptr;
@@ -48,7 +48,7 @@
b.Append(f->Block(), [&] { //
auto* a = b.Access(ty.ptr<workgroup, u32, read_write>(), wg, 1_i);
b.Call<spirv::ir::BuiltinCall>(ty.void_(), spirv::BuiltinFn::kAtomicStore, a, 1_u, 0_u,
- 0_u);
+ 1_u);
b.Return(f);
});
@@ -60,7 +60,7 @@
%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
$B2: {
%3:ptr<workgroup, u32, read_write> = access %wg, 1i
- %4:void = spirv.atomic_store %3, 1u, 0u, 0u
+ %4:void = spirv.atomic_store %3, 1u, 0u, 1u
ret
}
}
@@ -69,7 +69,17 @@
Run(Atomics);
auto* expect = R"(
-UNIMPLEMENTED
+$B1: { # root
+ %wg:ptr<workgroup, array<atomic<u32>, 4>, read_write> = var undef
+}
+
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B2: {
+ %3:ptr<workgroup, atomic<u32>, read_write> = access %wg, 1i
+ %4:void = atomicStore %3, 1u
+ ret
+ }
+}
)";
ASSERT_EQ(expect, str());
}
@@ -112,7 +122,7 @@
ASSERT_EQ(expect, str());
}
-TEST_F(SpirvReader_AtomicsTest, DISABLED_ArrayNested) {
+TEST_F(SpirvReader_AtomicsTest, ArrayNested) {
auto* f = b.ComputeFunction("main");
core::ir::Var* wg = nullptr;
@@ -144,7 +154,17 @@
Run(Atomics);
auto* expect = R"(
-UNIMPLEMENTED
+$B1: { # root
+ %wg:ptr<workgroup, array<array<array<atomic<u32>, 1>, 2>, 3>, read_write> = var undef
+}
+
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B2: {
+ %3:ptr<workgroup, atomic<u32>, read_write> = access %wg, 2i, 1i, 0i
+ %4:void = atomicStore %3, 1u
+ ret
+ }
+}
)";
ASSERT_EQ(expect, str());
}
diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc
index 0d194d3..28cdce7 100644
--- a/src/tint/lang/spirv/reader/lower/builtins.cc
+++ b/src/tint/lang/spirv/reader/lower/builtins.cc
@@ -218,6 +218,23 @@
case spirv::BuiltinFn::kOuterProduct:
OuterProduct(builtin);
break;
+ case spirv::BuiltinFn::kAtomicLoad:
+ case spirv::BuiltinFn::kAtomicStore:
+ case spirv::BuiltinFn::kAtomicExchange:
+ case spirv::BuiltinFn::kAtomicCompareExchange:
+ case spirv::BuiltinFn::kAtomicIAdd:
+ case spirv::BuiltinFn::kAtomicISub:
+ case spirv::BuiltinFn::kAtomicSMax:
+ case spirv::BuiltinFn::kAtomicSMin:
+ case spirv::BuiltinFn::kAtomicUMax:
+ case spirv::BuiltinFn::kAtomicUMin:
+ case spirv::BuiltinFn::kAtomicAnd:
+ case spirv::BuiltinFn::kAtomicOr:
+ case spirv::BuiltinFn::kAtomicXor:
+ case spirv::BuiltinFn::kAtomicIIncrement:
+ case spirv::BuiltinFn::kAtomicIDecrement:
+ // Ignore Atomics, they'll be handled by the `Atomics` transform.
+ break;
default:
TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func();
}
diff --git a/src/tint/lang/spirv/reader/lower/lower.cc b/src/tint/lang/spirv/reader/lower/lower.cc
index 86d7241..c5e4fb6 100644
--- a/src/tint/lang/spirv/reader/lower/lower.cc
+++ b/src/tint/lang/spirv/reader/lower/lower.cc
@@ -29,6 +29,7 @@
#include "src/tint/lang/core/ir/transform/remove_terminator_args.h"
#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/spirv/reader/lower/atomics.h"
#include "src/tint/lang/spirv/reader/lower/builtins.h"
#include "src/tint/lang/spirv/reader/lower/shader_io.h"
#include "src/tint/lang/spirv/reader/lower/vector_element_pointer.h"
@@ -47,6 +48,7 @@
RUN_TRANSFORM(lower::VectorElementPointer, mod);
RUN_TRANSFORM(lower::ShaderIO, mod);
RUN_TRANSFORM(lower::Builtins, mod);
+ RUN_TRANSFORM(lower::Atomics, mod);
// Remove the terminator args at this point. There are no logical short-circuiting operators in
// SPIR-V that we will lose track of, all the terminators are for hoisted values. We don't do