[spirv-reader][ir] Support propagating atomics through function params.
Add support for atomics passed as functions parameters to the spir-v IR
reader.
Bug: 406616153
Change-Id: I5170875b3e8e34aa4aa793fcc5ba2a57586b260f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/235615
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/reader/lower/atomics.cc b/src/tint/lang/spirv/reader/lower/atomics.cc
index 244d8ea..7c422ea 100644
--- a/src/tint/lang/spirv/reader/lower/atomics.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics.cc
@@ -198,21 +198,38 @@
}
void ConvertAtomicValue(core::ir::Value* val) {
- auto* res = val->As<core::ir::InstructionResult>();
- TINT_ASSERT(res);
+ tint::Switch( //
+ val, //
+ [&](core::ir::InstructionResult* res) {
+ auto* orig_ty = res->Type();
+ auto* atomic_ty = AtomicTypeFor(val, orig_ty);
+ res->SetType(atomic_ty);
- auto* orig_ty = res->Type();
- auto* atomic_ty = AtomicTypeFor(val, orig_ty);
- res->SetType(atomic_ty);
+ tint::Switch( //
+ res->Instruction(), //
+ [&](core::ir::Access* a) {
+ CheckForStructForking(a);
+ values_to_convert_.Push(a->Object());
+ }, //
+ [&](core::ir::Let* l) { values_to_convert_.Push(l->Value()); }, //
+ [&](core::ir::Var*) {}, //
+ TINT_ICE_ON_NO_MATCH);
+ },
+ [&](core::ir::FunctionParam* param) {
+ auto* orig_ty = param->Type();
+ auto* atomic_ty = AtomicTypeFor(val, orig_ty);
+ param->SetType(atomic_ty);
- tint::Switch( //
- res->Instruction(), //
- [&](core::ir::Access* a) {
- CheckForStructForking(a);
- values_to_convert_.Push(a->Object());
- }, //
- [&](core::ir::Let* l) { values_to_convert_.Push(l->Value()); }, //
- [&](core::ir::Var*) {}, //
+ for (auto& usage : param->Function()->UsagesUnsorted()) {
+ if (usage->instruction->Is<core::ir::Return>()) {
+ continue;
+ }
+
+ auto* call = usage->instruction->As<core::ir::Call>();
+ TINT_ASSERT(call);
+ values_to_convert_.Push(call->Args()[param->Index()]);
+ }
+ },
TINT_ICE_ON_NO_MATCH);
}
diff --git a/src/tint/lang/spirv/reader/lower/atomics_test.cc b/src/tint/lang/spirv/reader/lower/atomics_test.cc
index 95a82a5..0bf3fa7 100644
--- a/src/tint/lang/spirv/reader/lower/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/lower/atomics_test.cc
@@ -691,6 +691,75 @@
ASSERT_EQ(expect, str());
}
+TEST_F(SpirvReader_AtomicsTest, FunctionParam) {
+ auto* c = b.Function("c", ty.void_());
+ auto* p = b.FunctionParam("param", ty.ptr(workgroup, ty.array<u32, 4>(), read_write));
+ c->SetParams({p});
+
+ b.Append(c->Block(), [&] {
+ auto* a = b.Access(ty.ptr<workgroup, u32, read_write>(), p, 1_i);
+ b.Call<spirv::ir::BuiltinCall>(ty.void_(), spirv::BuiltinFn::kAtomicStore, a, 2_u, 0_u,
+ 1_u);
+
+ b.Return(c);
+ });
+
+ auto* f = b.ComputeFunction("main");
+
+ core::ir::Var* wg = nullptr;
+ b.Append(mod.root_block,
+ [&] { wg = b.Var("wg", ty.ptr(workgroup, ty.array<u32, 4>(), read_write)); });
+
+ b.Append(f->Block(), [&] { //
+ b.Call(ty.void_(), c, wg);
+ b.Return(f);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %wg:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%c = func(%param:ptr<workgroup, array<u32, 4>, read_write>):void {
+ $B2: {
+ %4:ptr<workgroup, u32, read_write> = access %param, 1i
+ %5:void = spirv.atomic_store %4, 2u, 0u, 1u
+ ret
+ }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B3: {
+ %7:void = call %c, %wg
+ ret
+ }
+}
+)";
+
+ ASSERT_EQ(src, str());
+ Run(Atomics);
+
+ auto* expect = R"(
+$B1: { # root
+ %wg:ptr<workgroup, array<atomic<u32>, 4>, read_write> = var undef
+}
+
+%c = func(%param:ptr<workgroup, array<atomic<u32>, 4>, read_write>):void {
+ $B2: {
+ %4:ptr<workgroup, atomic<u32>, read_write> = access %param, 1i
+ %5:void = atomicStore %4, 1u
+ ret
+ }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B3: {
+ %7:void = call %c, %wg
+ ret
+ }
+}
+)";
+ ASSERT_EQ(expect, str());
+}
+
TEST_F(SpirvReader_AtomicsTest, AtomicAdd) {
auto* f = b.ComputeFunction("main");
diff --git a/src/tint/lang/spirv/reader/parser/atomics_test.cc b/src/tint/lang/spirv/reader/parser/atomics_test.cc
index 5e71efc..70344e3 100644
--- a/src/tint/lang/spirv/reader/parser/atomics_test.cc
+++ b/src/tint/lang/spirv/reader/parser/atomics_test.cc
@@ -577,6 +577,125 @@
)");
}
+TEST_F(SpirvParser_AtomicsTest, FunctionParam) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %wg "wg"
+ OpName %main "main"
+ %int = OpTypeInt 32 1
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %uint_4 = OpConstant %uint 4
+ %int_1 = OpConstant %int 1
+ %arr = OpTypeArray %uint %uint_4
+ %ptr_arr = OpTypePointer Workgroup %arr
+ %ptr_uint = OpTypePointer Workgroup %uint
+ %void = OpTypeVoid
+ %10 = OpTypeFunction %void
+ %11 = OpTypeFunction %void %ptr_arr
+ %wg = OpVariable %ptr_arr Workgroup
+
+ %foo = OpFunction %void None %11
+ %param = OpFunctionParameter %ptr_arr
+ %foo_start = OpLabel
+ %42 = OpAccessChain %ptr_uint %param %int_1
+ OpAtomicStore %42 %uint_2 %uint_0 %uint_1
+ OpReturn
+ OpFunctionEnd
+
+ %main = OpFunction %void None %10
+ %45 = OpLabel
+ %44 = OpFunctionCall %void %foo %wg
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+$B1: { # root
+ %wg:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%2 = func(%3:ptr<workgroup, array<u32, 4>, read_write>):void {
+ $B2: {
+ %4:ptr<workgroup, u32, read_write> = access %3, 1i
+ %5:void = spirv.atomic_store %4, 2u, 0u, 1u
+ ret
+ }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B3: {
+ %7:void = call %2, %wg
+ ret
+ }
+}
+)");
+}
+
+// TODO(dsinclair): Requires support for variable pointers
+TEST_F(SpirvParser_AtomicsTest, DISABLED_FunctionParam_subpointer) {
+ EXPECT_IR(R"(
+ OpCapability Shader
+ OpCapability VariablePointers
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpExecutionMode %main LocalSize 1 1 1
+ OpName %wg "wg"
+ OpName %main "main"
+ %int = OpTypeInt 32 1
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %uint_4 = OpConstant %uint 4
+ %int_1 = OpConstant %int 1
+ %arr = OpTypeArray %uint %uint_4
+ %ptr_arr = OpTypePointer Workgroup %arr
+ %ptr_uint = OpTypePointer Workgroup %uint
+ %void = OpTypeVoid
+ %10 = OpTypeFunction %void
+ %11 = OpTypeFunction %void %ptr_uint
+ %wg = OpVariable %ptr_arr Workgroup
+
+ %foo = OpFunction %void None %11
+ %param = OpFunctionParameter %ptr_uint
+ %foo_start = OpLabel
+ OpAtomicStore %param %uint_2 %uint_0 %uint_1
+ OpReturn
+ OpFunctionEnd
+
+ %main = OpFunction %void None %10
+ %45 = OpLabel
+ %42 = OpAccessChain %ptr_uint %wg %int_1
+ %44 = OpFunctionCall %void %foo %42
+ OpReturn
+ OpFunctionEnd
+)",
+ R"(
+$B1: { # root
+ %wg:ptr<workgroup, array<u32, 4>, read_write> = var undef
+}
+
+%2 = func(%3:ptr<workgroup, array<u32, 4>, read_write>):void {
+ $B2: {
+ %4:ptr<workgroup, u32, read_write> = access %3, 1i
+ %5:void = spirv.atomic_store %4, 2u, 0u, 1u
+ ret
+ }
+}
+%main = @compute @workgroup_size(1u, 1u, 1u) func():void {
+ $B3: {
+ %7:void = call %2, %wg
+ ret
+ }
+}
+)");
+}
+
TEST_F(SpirvParser_AtomicsTest, AtomicAdd) {
EXPECT_IR(R"(
OpCapability Shader