[ir][spirv-writer] Emit workgroup variables
If requested via the generator option, generate an OpConstantNull
instruction and use that for the initializer.
Bug: tint:1906
Change-Id: Iff2840508d89971964b85f5a9e2a478e913665b2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134743
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index c2a5bff..a3f3351 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -172,6 +172,14 @@
});
}
+uint32_t GeneratorImplIr::ConstantNull(const type::Type* type) {
+ return constant_nulls_.GetOrCreate(type, [&]() {
+ auto id = module_.NextId();
+ module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
+ return id;
+ });
+}
+
uint32_t GeneratorImplIr::Type(const type::Type* ty) {
return types_.GetOrCreate(ty, [&]() {
auto id = module_.NextId();
@@ -726,9 +734,6 @@
}
uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) {
- // TODO(crbug.com/tint/1906): Remove this when we use it for emitting workgroup variables.
- (void)zero_init_workgroup_memory_;
-
auto id = module_.NextId();
auto* ptr = var->Type()->As<type::Pointer>();
TINT_ASSERT(Writer, ptr);
@@ -753,6 +758,17 @@
module_.PushType(spv::Op::OpVariable, operands);
break;
}
+ case builtin::AddressSpace::kWorkgroup: {
+ TINT_ASSERT(Writer, !current_function_);
+ OperandList operands = {ty, id, U32Operand(SpvStorageClassWorkgroup)};
+ if (zero_init_workgroup_memory_) {
+ // If requested, use the VK_KHR_zero_initialize_workgroup_memory to zero-initialize
+ // the workgroup variable using an null constant initializer.
+ operands.push_back(ConstantNull(ptr->StoreType()));
+ }
+ module_.PushType(spv::Op::OpVariable, operands);
+ break;
+ }
default: {
TINT_ICE(Writer, diagnostics_)
<< "unimplemented variable address space " << ptr->AddressSpace();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 4b7fe80..5793325 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -76,6 +76,11 @@
/// @returns the result ID of the constant
uint32_t Constant(const ir::Constant* constant);
+ /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
+ /// @param type the type to get the ID for
+ /// @returns the result ID of the OpConstantNull instruction
+ uint32_t ConstantNull(const type::Type* type);
+
/// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
/// @param ty the type to get the ID for
/// @returns the result ID of the type
@@ -198,6 +203,9 @@
/// The map of constants to their result IDs.
utils::Hashmap<const constant::Value*, uint32_t, 16> constants_;
+ /// The map of types to the result IDs of their OpConstantNull instructions.
+ utils::Hashmap<const type::Type*, uint32_t, 4> constant_nulls_;
+
/// The map of non-constant values to their result IDs.
utils::Hashmap<const ir::Value*, uint32_t, 8> values_;
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 10686cd..8ad2c10 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -297,5 +297,118 @@
)");
}
+TEST_F(SpvGeneratorImplTest, WorkgroupVar) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+ mod.SetName(v, "myvar");
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %1 "myvar"
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
+ auto* func = b.CreateFunction("foo", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
+ std::array{1u, 1u, 1u});
+ mod.functions.Push(func);
+
+ auto* store_ty = mod.Types().i32();
+ auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ auto* v = b.Declare(ty);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+ auto* load = b.Load(v);
+ auto* add = b.Add(store_ty, v, b.Constant(1_i));
+ auto* store = b.Store(v, add);
+ func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+
+ generator_.Generate();
+ EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "foo"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "foo"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%10 = OpConstant %3 1
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+%8 = OpLoad %3 %1
+%9 = OpIAdd %3 %1 %10
+OpStore %1 %9
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_ZeroInitializeWithExtension) {
+ auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+ builtin::Access::kReadWrite);
+ b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+ // Create a generator with the zero_init_workgroup_memory flag set to `true`.
+ spirv::GeneratorImplIr gen(&mod, true);
+ gen.Generate();
+ EXPECT_EQ(DumpModule(gen.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpName %5 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Workgroup %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace tint::writer::spirv