writers: Add flag to disable workgroup memory init

Bug: tint:1003
Change-Id: Ia30a2c51b5d3f8c6a01bed5299eac51dc3ad6337
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58843
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/.vscode/tasks.json b/.vscode/tasks.json
index f0482fc..ee4089f 100644
--- a/.vscode/tasks.json
+++ b/.vscode/tasks.json
@@ -99,7 +99,7 @@
             },
             "windows": {
                 // Generates a GN build directory at 'out/<build-type>' with the
-                // is_debug argument set to to true iff the build-type is Debug.
+                // is_debug argument set to true iff the build-type is Debug.
                 // A symbolic link to this build directory is created at 'out/active'
                 // which is used to track the active build directory.
                 "command": "/C",
@@ -143,4 +143,4 @@
             "description": "The type of build",
         },
     ]
-}
\ No newline at end of file
+}
diff --git a/samples/main.cc b/samples/main.cc
index dd7970f..e76342d 100644
--- a/samples/main.cc
+++ b/samples/main.cc
@@ -54,6 +54,7 @@
 
   bool parse_only = false;
   bool dump_ast = false;
+  bool disable_workgroup_init = false;
   bool validate = false;
   bool demangle = false;
   bool dump_inspector_bindings = false;
@@ -93,6 +94,7 @@
                                 robustness
   --parse-only              -- Stop after parsing the input
   --dump-ast                -- Dump the generated AST to stdout
+  --disable-workgroup-init  -- Disable workgroup memory zero initialization.
   --demangle                -- Preserve original source names. Demangle them.
                                Affects AST dumping, and text-based output languages.
   --dump-inspector-bindings -- Dump reflection data about bindins to stdout.
@@ -401,6 +403,8 @@
       opts->parse_only = true;
     } else if (arg == "--dump-ast") {
       opts->dump_ast = true;
+    } else if (arg == "--disable-workgroup-init") {
+      opts->disable_workgroup_init = true;
     } else if (arg == "--demangle") {
       opts->demangle = true;
     } else if (arg == "--dump-inspector-bindings") {
@@ -603,6 +607,7 @@
 #if TINT_BUILD_SPV_WRITER
   // TODO(jrprice): Provide a way for the user to set non-default options.
   tint::writer::spirv::Options gen_options;
+  gen_options.disable_workgroup_init = options.disable_workgroup_init;
   auto result = tint::writer::spirv::Generate(program, gen_options);
   if (!result.success) {
     PrintWGSL(std::cerr, *program);
@@ -670,6 +675,7 @@
 #if TINT_BUILD_MSL_WRITER
   // TODO(jrprice): Provide a way for the user to set non-default options.
   tint::writer::msl::Options gen_options;
+  gen_options.disable_workgroup_init = options.disable_workgroup_init;
   auto result = tint::writer::msl::Generate(program, gen_options);
   if (!result.success) {
     PrintWGSL(std::cerr, *program);
@@ -721,6 +727,7 @@
 #if TINT_BUILD_HLSL_WRITER
   // TODO(jrprice): Provide a way for the user to set non-default options.
   tint::writer::hlsl::Options gen_options;
+  gen_options.disable_workgroup_init = options.disable_workgroup_init;
   auto result = tint::writer::hlsl::Generate(program, gen_options);
   if (!result.success) {
     PrintWGSL(std::cerr, *program);
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index baf8897..26cd42c 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -31,6 +31,7 @@
 #include "src/transform/zero_init_workgroup_memory.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::Hlsl);
+TINT_INSTANTIATE_TYPEINFO(tint::transform::Hlsl::Config);
 
 namespace tint {
 namespace transform {
@@ -38,18 +39,22 @@
 Hlsl::Hlsl() = default;
 Hlsl::~Hlsl() = default;
 
-Output Hlsl::Run(const Program* in, const DataMap&) {
+Output Hlsl::Run(const Program* in, const DataMap& inputs) {
   Manager manager;
   DataMap data;
 
+  auto* cfg = inputs.Get<Config>();
+
   // Attempt to convert `loop`s into for-loops. This is to try and massage the
   // output into something that will not cause FXC to choke or misbehave.
   manager.Add<FoldTrivialSingleUseLets>();
   manager.Add<LoopToForLoop>();
 
-  // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
-  // ZeroInitWorkgroupMemory may inject new builtin parameters.
-  manager.Add<ZeroInitWorkgroupMemory>();
+  if (!cfg || !cfg->disable_workgroup_init) {
+    // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
+    // ZeroInitWorkgroupMemory may inject new builtin parameters.
+    manager.Add<ZeroInitWorkgroupMemory>();
+  }
   manager.Add<CanonicalizeEntryPointIO>();
   manager.Add<InlinePointerLets>();
   // Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
@@ -97,5 +102,9 @@
                  ctx.dst->WorkgroupSize(1)});
 }
 
+Hlsl::Config::Config(bool disable_wi) : disable_workgroup_init(disable_wi) {}
+Hlsl::Config::Config(const Config&) = default;
+Hlsl::Config::~Config() = default;
+
 }  // namespace transform
 }  // namespace tint
diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h
index 091903e..01aa2d5 100644
--- a/src/transform/hlsl.h
+++ b/src/transform/hlsl.h
@@ -29,6 +29,23 @@
 /// behavior.
 class Hlsl : public Castable<Hlsl, Transform> {
  public:
+  /// Configuration options for the Hlsl sanitizer transform.
+  struct Config : public Castable<Data, transform::Data> {
+    /// Constructor
+    /// @param disable_workgroup_init `true` to disable workgroup memory zero
+    ///        initialization
+    explicit Config(bool disable_workgroup_init = false);
+
+    /// Copy constructor
+    Config(const Config&);
+
+    /// Destructor
+    ~Config() override;
+
+    /// Set to `true` to disable workgroup memory zero initialization
+    bool disable_workgroup_init = false;
+  };
+
   /// Constructor
   Hlsl();
   ~Hlsl() override;
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index 38e60f4..53ead18 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -73,9 +73,11 @@
     }
   }
 
-  // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
-  // ZeroInitWorkgroupMemory may inject new builtin parameters.
-  manager.Add<ZeroInitWorkgroupMemory>();
+  if (!cfg || !cfg->disable_workgroup_init) {
+    // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
+    // ZeroInitWorkgroupMemory may inject new builtin parameters.
+    manager.Add<ZeroInitWorkgroupMemory>();
+  }
   manager.Add<CanonicalizeEntryPointIO>();
   manager.Add<ExternalTextureTransform>();
   manager.Add<PromoteInitializersToConstVar>();
@@ -280,9 +282,12 @@
   }
 }
 
-Msl::Config::Config(uint32_t buffer_size_ubo_idx, uint32_t sample_mask)
+Msl::Config::Config(uint32_t buffer_size_ubo_idx,
+                    uint32_t sample_mask,
+                    bool disable_wi)
     : buffer_size_ubo_index(buffer_size_ubo_idx),
-      fixed_sample_mask(sample_mask) {}
+      fixed_sample_mask(sample_mask),
+      disable_workgroup_init(disable_wi) {}
 Msl::Config::Config(const Config&) = default;
 Msl::Config::~Config() = default;
 
diff --git a/src/transform/msl.h b/src/transform/msl.h
index 0fb6cb5..bb9a013 100644
--- a/src/transform/msl.h
+++ b/src/transform/msl.h
@@ -33,8 +33,11 @@
     /// Constructor
     /// @param buffer_size_ubo_idx the index to use for the buffer size UBO
     /// @param sample_mask the fixed sample mask to use for fragment shaders
-    explicit Config(uint32_t buffer_size_ubo_idx,
-                    uint32_t sample_mask = 0xFFFFFFFF);
+    /// @param disable_workgroup_init `true` to disable workgroup memory zero
+    ///        initialization
+    Config(uint32_t buffer_size_ubo_idx,
+           uint32_t sample_mask = 0xFFFFFFFF,
+           bool disable_workgroup_init = false);
 
     /// Copy constructor
     Config(const Config&);
@@ -43,10 +46,13 @@
     ~Config() override;
 
     /// The index to use when generating a UBO to receive storage buffer sizes.
-    uint32_t buffer_size_ubo_index;
+    uint32_t buffer_size_ubo_index = 0;
 
     /// The fixed sample mask to combine with fragment shader outputs.
-    uint32_t fixed_sample_mask;
+    uint32_t fixed_sample_mask = 0xFFFFFFFF;
+
+    /// Set to `true` to disable workgroup memory zero initialization
+    bool disable_workgroup_init = false;
   };
 
   /// Information produced by the sanitizer that users may need to act on.
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 0ce47c6..2b466b1 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -45,8 +45,12 @@
 Spirv::~Spirv() = default;
 
 Output Spirv::Run(const Program* in, const DataMap& data) {
+  auto* cfg = data.Get<Config>();
+
   Manager manager;
-  manager.Add<ZeroInitWorkgroupMemory>();
+  if (!cfg || !cfg->disable_workgroup_init) {
+    manager.Add<ZeroInitWorkgroupMemory>();
+  }
   manager.Add<InlinePointerLets>();  // Required for arrayLength()
   manager.Add<Simplify>();           // Required for arrayLength()
   manager.Add<FoldConstants>();
@@ -58,8 +62,6 @@
     return transformedInput;
   }
 
-  auto* cfg = data.Get<Config>();
-
   ProgramBuilder out;
   CloneContext ctx(&out, &transformedInput.program);
   HandleEntryPointIOTypes(ctx);
@@ -427,7 +429,8 @@
   }
 }
 
-Spirv::Config::Config(bool emit_vps) : emit_vertex_point_size(emit_vps) {}
+Spirv::Config::Config(bool emit_vps, bool disable_wi)
+    : emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {}
 
 Spirv::Config::Config(const Config&) = default;
 Spirv::Config::~Config() = default;
diff --git a/src/transform/spirv.h b/src/transform/spirv.h
index dba32b5..0b85ebe 100644
--- a/src/transform/spirv.h
+++ b/src/transform/spirv.h
@@ -35,7 +35,10 @@
   struct Config : public Castable<Config, Data> {
     /// Constructor
     /// @param emit_vertex_point_size `true` to generate a PointSize builtin
-    explicit Config(bool emit_vertex_point_size = false);
+    /// @param disable_workgroup_init `true` to disable workgroup memory zero
+    ///        initialization
+    Config(bool emit_vertex_point_size = false,
+           bool disable_workgroup_init = false);
 
     /// Copy constructor.
     Config(const Config&);
@@ -49,7 +52,10 @@
 
     /// Set to `true` to generate a PointSize builtin and have it set to 1.0
     /// from all vertex shaders in the module.
-    bool emit_vertex_point_size;
+    bool emit_vertex_point_size = false;
+
+    /// Set to `true` to disable workgroup memory zero initialization
+    bool disable_workgroup_init = false;
   };
 
   /// Constructor
diff --git a/src/writer/hlsl/generator.cc b/src/writer/hlsl/generator.cc
index 5dc8693..259f0b6 100644
--- a/src/writer/hlsl/generator.cc
+++ b/src/writer/hlsl/generator.cc
@@ -25,12 +25,14 @@
 Result::~Result() = default;
 Result::Result(const Result&) = default;
 
-Result Generate(const Program* program, const Options&) {
+Result Generate(const Program* program, const Options& options) {
   Result result;
 
   // Run the HLSL sanitizer.
   transform::Hlsl sanitizer;
-  auto output = sanitizer.Run(program);
+  transform::DataMap transform_input;
+  transform_input.Add<transform::Hlsl::Config>(options.disable_workgroup_init);
+  auto output = sanitizer.Run(program, transform_input);
   if (!output.program.IsValid()) {
     result.success = false;
     result.error = output.program.Diagnostics().str();
diff --git a/src/writer/hlsl/generator.h b/src/writer/hlsl/generator.h
index f467348..1ff8d98 100644
--- a/src/writer/hlsl/generator.h
+++ b/src/writer/hlsl/generator.h
@@ -35,7 +35,10 @@
 class GeneratorImpl;
 
 /// Configuration options used for generating HLSL.
-struct Options {};
+struct Options {
+  /// Set to `true` to disable workgroup memory zero initialization
+  bool disable_workgroup_init = false;
+};
 
 /// The result produced when generating HLSL.
 struct Result {
diff --git a/src/writer/msl/generator.cc b/src/writer/msl/generator.cc
index e73541e..7155e94 100644
--- a/src/writer/msl/generator.cc
+++ b/src/writer/msl/generator.cc
@@ -32,7 +32,8 @@
   transform::Msl sanitizer;
   transform::DataMap transform_input;
   transform_input.Add<transform::Msl::Config>(options.buffer_size_ubo_index,
-                                              options.fixed_sample_mask);
+                                              options.fixed_sample_mask,
+                                              options.disable_workgroup_init);
   auto output = sanitizer.Run(program, transform_input);
   if (!output.program.IsValid()) {
     result.success = false;
diff --git a/src/writer/msl/generator.h b/src/writer/msl/generator.h
index 9a182d3..4de190b 100644
--- a/src/writer/msl/generator.h
+++ b/src/writer/msl/generator.h
@@ -39,6 +39,9 @@
   /// The fixed sample mask to combine with fragment shader outputs.
   /// Defaults to 0xFFFFFFFF.
   uint32_t fixed_sample_mask = 0xFFFFFFFF;
+
+  /// Set to `true` to disable workgroup memory zero initialization
+  bool disable_workgroup_init = false;
 };
 
 /// The result produced when generating MSL.
diff --git a/src/writer/spirv/generator.cc b/src/writer/spirv/generator.cc
index 6270e50..ec249ac 100644
--- a/src/writer/spirv/generator.cc
+++ b/src/writer/spirv/generator.cc
@@ -31,7 +31,8 @@
   // Run the SPIR-V sanitizer.
   transform::Spirv sanitizer;
   transform::DataMap transform_input;
-  transform_input.Add<transform::Spirv::Config>(options.emit_vertex_point_size);
+  transform_input.Add<transform::Spirv::Config>(options.emit_vertex_point_size,
+                                                options.disable_workgroup_init);
   auto output = sanitizer.Run(program, transform_input);
   if (!output.program.IsValid()) {
     result.success = false;
diff --git a/src/writer/spirv/generator.h b/src/writer/spirv/generator.h
index 2b06032..aaacd94 100644
--- a/src/writer/spirv/generator.h
+++ b/src/writer/spirv/generator.h
@@ -38,6 +38,9 @@
   /// Set to `true` to generate a PointSize builtin and have it set to 1.0
   /// from all vertex shaders in the module.
   bool emit_vertex_point_size = true;
+
+  /// Set to `true` to disable workgroup memory zero initialization
+  bool disable_workgroup_init = false;
 };
 
 /// The result produced when generating SPIR-V.