[spirv][ir] Modify the memory model if Vulkan Memory Model requested

If the Vulkan Memory Model is requested add the needed header entries
and disable the emission of `Coherent`.

Bug: 348702031
Change-Id: Id82dade9c297c8b34f86e5cf6ad640c73b655fc0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202995
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/writer/common/helper_test.h b/src/tint/lang/spirv/writer/common/helper_test.h
index 56b52bb..3e1cbbc 100644
--- a/src/tint/lang/spirv/writer/common/helper_test.h
+++ b/src/tint/lang/spirv/writer/common/helper_test.h
@@ -122,7 +122,12 @@
             return false;
         }
 
-        auto spirv = PrintModule(mod, zero_init_workgroup_memory);
+        if (zero_init_workgroup_memory) {
+            options.disable_workgroup_init = false;
+            options.use_zero_initialize_workgroup_memory_extension = true;
+        }
+
+        auto spirv = PrintModule(mod, options);
         if (spirv != Success) {
             err_ = spirv.Failure().reason.Str();
             return false;
diff --git a/src/tint/lang/spirv/writer/common/options.h b/src/tint/lang/spirv/writer/common/options.h
index 9af79ab..6d7fd7f 100644
--- a/src/tint/lang/spirv/writer/common/options.h
+++ b/src/tint/lang/spirv/writer/common/options.h
@@ -178,6 +178,9 @@
     /// Set to `true` to disable the polyfills on integer division and modulo.
     bool disable_polyfill_integer_div_mod = false;
 
+    /// Set to `true` if the Vulkan Memory Model should be used
+    bool use_vulkan_memory_model = false;
+
     /// Reflect the fields of this class so that it can be used by tint::ForeachField()
     TINT_REFLECT(Options,
                  bindings,
@@ -193,7 +196,8 @@
                  pass_matrix_by_pointer,
                  experimental_require_subgroup_uniform_control_flow,
                  polyfill_dot_4x8_packed,
-                 disable_polyfill_integer_div_mod);
+                 disable_polyfill_integer_div_mod,
+                 use_vulkan_memory_model);
 };
 
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 47a150a..721d72c 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -98,6 +98,7 @@
 #include "src/tint/lang/spirv/writer/common/binary_writer.h"
 #include "src/tint/lang/spirv/writer/common/function.h"
 #include "src/tint/lang/spirv/writer/common/module.h"
+#include "src/tint/lang/spirv/writer/common/options.h"
 #include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
 #include "src/tint/utils/containers/hashmap.h"
 #include "src/tint/utils/containers/vector.h"
@@ -183,10 +184,12 @@
   public:
     /// Constructor
     /// @param module the Tint IR module to generate
-    /// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
-    ///                                   storage class with OpConstantNull
-    Printer(core::ir::Module& module, bool zero_init_workgroup_memory)
-        : ir_(module), b_(module), zero_init_workgroup_memory_(zero_init_workgroup_memory) {}
+    /// @param options the printer options
+    Printer(core::ir::Module& module, const Options& options)
+        : ir_(module), b_(module), options_(options) {
+        zero_init_workgroup_memory_ = !options.disable_workgroup_init &&
+                                      options.use_zero_initialize_workgroup_memory_extension;
+    }
 
     /// @returns the generated SPIR-V code on success, or failure
     Result<std::vector<uint32_t>> Code() {
@@ -218,6 +221,7 @@
   private:
     core::ir::Module& ir_;
     core::ir::Builder b_;
+    Options options_;
     writer::Module module_;
     BinaryWriter writer_;
 
@@ -294,8 +298,18 @@
         }
 
         module_.PushCapability(SpvCapabilityShader);
-        module_.PushMemoryModel(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical),
-                                                         U32Operand(SpvMemoryModelGLSL450)});
+
+        if (options_.use_vulkan_memory_model) {
+            module_.PushExtension("SPV_KHR_vulkan_memory_model");
+            module_.PushCapability(SpvCapabilityVulkanMemoryModelKHR);
+            // Required for the `Device` scope on atomic operations
+            module_.PushCapability(SpvCapabilityVulkanMemoryModelDeviceScopeKHR);
+            module_.PushMemoryModel(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical),
+                                                             U32Operand(SpvMemoryModelVulkanKHR)});
+        } else {
+            module_.PushMemoryModel(spv::Op::OpMemoryModel, {U32Operand(SpvAddressingModelLogical),
+                                                             U32Operand(SpvMemoryModelGLSL450)});
+        }
 
         // Emit module-scope declarations.
         EmitRootBlock(ir_.root_block);
@@ -2289,7 +2303,7 @@
                                           {id, U32Operand(SpvDecorationNonReadable)});
                     }
                 }
-                if (access == core::Access::kReadWrite) {
+                if (!options_.use_vulkan_memory_model && access == core::Access::kReadWrite) {
                     module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationCoherent)});
                 }
 
@@ -2454,13 +2468,12 @@
 
 }  // namespace
 
-tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module,
-                                          bool zero_init_workgroup_memory) {
-    return Printer{module, zero_init_workgroup_memory}.Code();
+tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module, const Options& options) {
+    return Printer{module, options}.Code();
 }
 
-tint::Result<Module> PrintModule(core::ir::Module& module, bool zero_init_workgroup_memory) {
-    return Printer{module, zero_init_workgroup_memory}.Module();
+tint::Result<Module> PrintModule(core::ir::Module& module, const Options& options) {
+    return Printer{module, options}.Module();
 }
 
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/printer/printer.h b/src/tint/lang/spirv/writer/printer/printer.h
index e82b01c..79abc31 100644
--- a/src/tint/lang/spirv/writer/printer/printer.h
+++ b/src/tint/lang/spirv/writer/printer/printer.h
@@ -32,6 +32,7 @@
 #include <vector>
 
 #include "src/tint/lang/spirv/writer/common/module.h"
+#include "src/tint/lang/spirv/writer/common/options.h"
 #include "src/tint/utils/result/result.h"
 
 // Forward declarations
@@ -43,16 +44,13 @@
 
 /// @returns the generated SPIR-V instructions on success, or failure
 /// @param module the Tint IR module to generate
-/// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
-///                                   storage class with OpConstantNull
-tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module,
-                                          bool zero_init_workgroup_memory);
+/// @param options the printer options
+tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module, const Options& options);
 
 /// @returns the generated SPIR-V module on success, or failure
 /// @param module the Tint IR module to generate
-/// @param zero_init_workgroup_memory `true` to initialize all the variables in the Workgroup
-///                                   storage class with OpConstantNull
-tint::Result<Module> PrintModule(core::ir::Module& module, bool zero_init_workgroup_memory);
+/// @param options the printer options
+tint::Result<Module> PrintModule(core::ir::Module& module, const Options& options);
 
 }  // namespace tint::spirv::writer
 
diff --git a/src/tint/lang/spirv/writer/var_test.cc b/src/tint/lang/spirv/writer/var_test.cc
index 01d2da3..24b269a 100644
--- a/src/tint/lang/spirv/writer/var_test.cc
+++ b/src/tint/lang/spirv/writer/var_test.cc
@@ -236,6 +236,71 @@
 )");
 }
 
+TEST_F(SpirvWriterTest, StorageVar_NoCoherentWithVulkan) {
+    auto* v = b.Var("v", ty.ptr<storage, i32, read_write>());
+    v->SetBindingPoint(0, 0);
+    mod.root_block->Append(v);
+
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute,
+                            std::array{1u, 1u, 1u});
+    b.Append(func->Block(), [&] {
+        auto* load = b.Load(v);
+        auto* add = b.Add(ty.i32(), load, 1_i);
+        b.Store(v, add);
+        b.Return(func);
+        mod.SetName(load, "load");
+        mod.SetName(add, "add");
+    });
+
+    Options opts;
+    opts.use_vulkan_memory_model = true;
+
+    ASSERT_TRUE(Generate(opts)) << Error() << output_;
+    EXPECT_INST(R"(               OpCapability Shader
+               OpCapability VulkanMemoryModel
+               OpCapability VulkanMemoryModelDeviceScope
+               OpExtension "SPV_KHR_vulkan_memory_model"
+               OpMemoryModel Logical Vulkan
+               OpEntryPoint GLCompute %foo "foo"
+               OpExecutionMode %foo LocalSize 1 1 1
+
+               ; Debug Information
+               OpMemberName %tint_symbol_1 0 "tint_symbol"
+               OpName %tint_symbol_1 "tint_symbol_1"    ; id %3
+               OpName %foo "foo"                        ; id %5
+               OpName %load "load"                      ; id %13
+               OpName %add "add"                        ; id %14
+
+               ; Annotations
+               OpMemberDecorate %tint_symbol_1 0 Offset 0
+               OpDecorate %tint_symbol_1 Block
+               OpDecorate %1 DescriptorSet 0
+               OpDecorate %1 Binding 0
+
+               ; Types, variables and constants
+        %int = OpTypeInt 32 1
+%tint_symbol_1 = OpTypeStruct %int                  ; Block
+%_ptr_StorageBuffer_tint_symbol_1 = OpTypePointer StorageBuffer %tint_symbol_1
+          %1 = OpVariable %_ptr_StorageBuffer_tint_symbol_1 StorageBuffer   ; DescriptorSet 0, Binding 0
+       %void = OpTypeVoid
+          %7 = OpTypeFunction %void
+%_ptr_StorageBuffer_int = OpTypePointer StorageBuffer %int
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+      %int_1 = OpConstant %int 1
+
+               ; Function foo
+        %foo = OpFunction %void None %7
+          %8 = OpLabel
+          %9 = OpAccessChain %_ptr_StorageBuffer_int %1 %uint_0
+       %load = OpLoad %int %9
+        %add = OpIAdd %int %load %int_1
+         %16 = OpAccessChain %_ptr_StorageBuffer_int %1 %uint_0
+               OpStore %16 %add
+               OpReturn
+               OpFunctionEnd)");
+}
+
 TEST_F(SpirvWriterTest, StorageVar_WriteOnly) {
     auto* v = b.Var("v", ty.ptr<storage, i32, write>());
     v->SetBindingPoint(0, 0);
diff --git a/src/tint/lang/spirv/writer/writer.cc b/src/tint/lang/spirv/writer/writer.cc
index d172299..5b30750 100644
--- a/src/tint/lang/spirv/writer/writer.cc
+++ b/src/tint/lang/spirv/writer/writer.cc
@@ -40,9 +40,6 @@
 namespace tint::spirv::writer {
 
 Result<Output> Generate(core::ir::Module& ir, const Options& options) {
-    bool zero_initialize_workgroup_memory =
-        !options.disable_workgroup_init && options.use_zero_initialize_workgroup_memory_extension;
-
     {
         auto res = ValidateBindingOptions(options);
         if (res != Success) {
@@ -58,7 +55,7 @@
     }
 
     // Generate the SPIR-V code.
-    auto spirv = Print(ir, zero_initialize_workgroup_memory);
+    auto spirv = Print(ir, options);
     if (spirv != Success) {
         return std::move(spirv.Failure());
     }
diff --git a/src/tint/lang/spirv/writer/writer_test.cc b/src/tint/lang/spirv/writer/writer_test.cc
index c00563f..eb4e367 100644
--- a/src/tint/lang/spirv/writer/writer_test.cc
+++ b/src/tint/lang/spirv/writer/writer_test.cc
@@ -40,6 +40,17 @@
     EXPECT_INST("OpMemoryModel Logical GLSL450");
 }
 
+TEST_F(SpirvWriterTest, ModuleHeader_VulkanMemoryModel) {
+    Options opts;
+    opts.use_vulkan_memory_model = true;
+
+    ASSERT_TRUE(Generate(opts)) << Error() << output_;
+    EXPECT_INST("OpExtension \"SPV_KHR_vulkan_memory_model\"");
+    EXPECT_INST("OpCapability VulkanMemoryModel");
+    EXPECT_INST("OpCapability VulkanMemoryModelDeviceScope");
+    EXPECT_INST("OpMemoryModel Logical Vulkan");
+}
+
 TEST_F(SpirvWriterTest, Unreachable) {
     auto* func = b.Function("foo", ty.void_());
     b.Append(func->Block(), [&] {