[ir][spirv-writer] Emit workgroup variables

If requested via the generator option, generate an OpConstantNull
instruction and use that for the initializer.

Bug: tint:1906
Change-Id: Iff2840508d89971964b85f5a9e2a478e913665b2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134743
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc
index c2a5bff..a3f3351 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc
@@ -172,6 +172,14 @@
     });
 }
 
+uint32_t GeneratorImplIr::ConstantNull(const type::Type* type) {
+    return constant_nulls_.GetOrCreate(type, [&]() {
+        auto id = module_.NextId();
+        module_.PushType(spv::Op::OpConstantNull, {Type(type), id});
+        return id;
+    });
+}
+
 uint32_t GeneratorImplIr::Type(const type::Type* ty) {
     return types_.GetOrCreate(ty, [&]() {
         auto id = module_.NextId();
@@ -726,9 +734,6 @@
 }
 
 uint32_t GeneratorImplIr::EmitVar(const ir::Var* var) {
-    // TODO(crbug.com/tint/1906): Remove this when we use it for emitting workgroup variables.
-    (void)zero_init_workgroup_memory_;
-
     auto id = module_.NextId();
     auto* ptr = var->Type()->As<type::Pointer>();
     TINT_ASSERT(Writer, ptr);
@@ -753,6 +758,17 @@
             module_.PushType(spv::Op::OpVariable, operands);
             break;
         }
+        case builtin::AddressSpace::kWorkgroup: {
+            TINT_ASSERT(Writer, !current_function_);
+            OperandList operands = {ty, id, U32Operand(SpvStorageClassWorkgroup)};
+            if (zero_init_workgroup_memory_) {
+                // If requested, use the VK_KHR_zero_initialize_workgroup_memory to zero-initialize
+                // the workgroup variable using an null constant initializer.
+                operands.push_back(ConstantNull(ptr->StoreType()));
+            }
+            module_.PushType(spv::Op::OpVariable, operands);
+            break;
+        }
         default: {
             TINT_ICE(Writer, diagnostics_)
                 << "unimplemented variable address space " << ptr->AddressSpace();
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.h b/src/tint/writer/spirv/ir/generator_impl_ir.h
index 4b7fe80..5793325 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir.h
+++ b/src/tint/writer/spirv/ir/generator_impl_ir.h
@@ -76,6 +76,11 @@
     /// @returns the result ID of the constant
     uint32_t Constant(const ir::Constant* constant);
 
+    /// Get the result ID of the OpConstantNull instruction for `type`, emitting it if necessary.
+    /// @param type the type to get the ID for
+    /// @returns the result ID of the OpConstantNull instruction
+    uint32_t ConstantNull(const type::Type* type);
+
     /// Get the result ID of the type `ty`, emitting a type declaration instruction if necessary.
     /// @param ty the type to get the ID for
     /// @returns the result ID of the type
@@ -198,6 +203,9 @@
     /// The map of constants to their result IDs.
     utils::Hashmap<const constant::Value*, uint32_t, 16> constants_;
 
+    /// The map of types to the result IDs of their OpConstantNull instructions.
+    utils::Hashmap<const type::Type*, uint32_t, 4> constant_nulls_;
+
     /// The map of non-constant values to their result IDs.
     utils::Hashmap<const ir::Value*, uint32_t, 8> values_;
 
diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
index 10686cd..8ad2c10 100644
--- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
+++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc
@@ -297,5 +297,118 @@
 )");
 }
 
+TEST_F(SpvGeneratorImplTest, WorkgroupVar) {
+    auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+                                              builtin::Access::kReadWrite);
+    b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+    generator_.Generate();
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_Name) {
+    auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+                                              builtin::Access::kReadWrite);
+    auto* v = b.Declare(ty);
+    b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+    mod.SetName(v, "myvar");
+
+    generator_.Generate();
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "unused_entry_point"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %1 "myvar"
+OpName %4 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_LoadAndStore) {
+    auto* func = b.CreateFunction("foo", mod.Types().void_(), ir::Function::PipelineStage::kCompute,
+                                  std::array{1u, 1u, 1u});
+    mod.functions.Push(func);
+
+    auto* store_ty = mod.Types().i32();
+    auto* ty = mod.Types().Get<type::Pointer>(store_ty, builtin::AddressSpace::kWorkgroup,
+                                              builtin::Access::kReadWrite);
+    auto* v = b.Declare(ty);
+    b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{v});
+
+    auto* load = b.Load(v);
+    auto* add = b.Add(store_ty, v, b.Constant(1_i));
+    auto* store = b.Store(v, add);
+    func->StartTarget()->SetInstructions(utils::Vector{load, add, store, b.Return(func)});
+
+    generator_.Generate();
+    EXPECT_EQ(DumpModule(generator_.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %4 "foo"
+OpExecutionMode %4 LocalSize 1 1 1
+OpName %4 "foo"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%1 = OpVariable %2 Workgroup
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%10 = OpConstant %3 1
+%4 = OpFunction %5 None %6
+%7 = OpLabel
+%8 = OpLoad %3 %1
+%9 = OpIAdd %3 %1 %10
+OpStore %1 %9
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(SpvGeneratorImplTest, WorkgroupVar_ZeroInitializeWithExtension) {
+    auto* ty = mod.Types().Get<type::Pointer>(mod.Types().i32(), builtin::AddressSpace::kWorkgroup,
+                                              builtin::Access::kReadWrite);
+    b.CreateRootBlockIfNeeded()->SetInstructions(utils::Vector{b.Declare(ty)});
+
+    // Create a generator with the zero_init_workgroup_memory flag set to `true`.
+    spirv::GeneratorImplIr gen(&mod, true);
+    gen.Generate();
+    EXPECT_EQ(DumpModule(gen.Module()), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %5 "unused_entry_point"
+OpExecutionMode %5 LocalSize 1 1 1
+OpName %5 "unused_entry_point"
+%3 = OpTypeInt 32 1
+%2 = OpTypePointer Workgroup %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Workgroup %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%5 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
 }  // namespace
 }  // namespace tint::writer::spirv