transform: Add ZeroInitWorkgroupMemory

Zero-initializes all workgroup storage class variables referenced by each entry point.

Bug: tint:280
Fixed: tint:911
Change-Id: I3fca26a10f015f08fedef404720bbe6fd7b343a9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55243
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index b6e4b20..0ff76ce 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -584,6 +584,8 @@
     "transform/vertex_pulling.h",
     "transform/wrap_arrays_in_structs.cc",
     "transform/wrap_arrays_in_structs.h",
+    "transform/zero_init_workgroup_memory.cc",
+    "transform/zero_init_workgroup_memory.h",
     "utils/enum_set.h",
     "utils/get_or_create.h",
     "utils/hash.h",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 5417352..fc5d550 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -309,6 +309,8 @@
   transform/vertex_pulling.h
   transform/wrap_arrays_in_structs.cc
   transform/wrap_arrays_in_structs.h
+  transform/zero_init_workgroup_memory.cc
+  transform/zero_init_workgroup_memory.h
   sem/bool_type.cc
   sem/bool_type.h
   sem/depth_texture_type.cc
@@ -873,6 +875,7 @@
       transform/test_helper.h
       transform/vertex_pulling_test.cc
       transform/wrap_arrays_in_structs_test.cc
+      transform/zero_init_workgroup_memory_test.cc
     )
   endif()
 
diff --git a/src/transform/zero_init_workgroup_memory.cc b/src/transform/zero_init_workgroup_memory.cc
new file mode 100644
index 0000000..53f9618
--- /dev/null
+++ b/src/transform/zero_init_workgroup_memory.cc
@@ -0,0 +1,200 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/zero_init_workgroup_memory.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "src/program_builder.h"
+#include "src/sem/atomic_type.h"
+#include "src/sem/function.h"
+#include "src/sem/variable.h"
+#include "src/utils/get_or_create.h"
+
+namespace tint {
+namespace transform {
+
+// Implementation state (PIMPL) of the ZeroInitWorkgroupMemory transform
+struct ZeroInitWorkgroupMemory::State {
+  /// Clone context; nodes are built in, and symbols cloned to, ctx.dst
+  CloneContext& ctx;
+  /// Output list that receives every generated zeroing statement
+  ast::StatementList& stmts;
+
+  /// Zero() appends to `stmts` the statements required to zero initialize the
+  /// workgroup storage expression of type `ty` that `get_expr` produces.
+  /// @param ty the expression type
+  /// @param get_expr a function that builds the AST nodes for the expression
+  void Zero(const sem::Type* ty,
+            const std::function<ast::Expression*()>& get_expr) {
+    if (auto* atomic_ty = ty->As<sem::Atomic>()) {  // needs atomicStore()
+      auto* zero_value =
+          ctx.dst->Construct(CreateASTTypeFor(&ctx, atomic_ty->Type()));
+      auto* target_ptr = ctx.dst->AddressOf(get_expr());
+      auto* call = ctx.dst->Call("atomicStore", target_ptr, zero_value);
+      stmts.emplace_back(ctx.dst->create<ast::CallStatement>(call));
+      return;
+    }
+
+    if (CanZero(ty)) {  // no atomics inside: a single assignment of T()
+      auto* target = get_expr();
+      auto* zero_value = ctx.dst->Construct(CreateASTTypeFor(&ctx, ty));
+      stmts.emplace_back(
+          ctx.dst->create<ast::AssignmentStatement>(target, zero_value));
+      return;
+    }
+
+    if (auto* struct_ty = ty->As<sem::Struct>()) {  // member by member
+      for (auto* field : struct_ty->Members()) {
+        auto field_name = ctx.Clone(field->Declaration()->symbol());
+        auto field_expr = [&] {
+          return ctx.dst->MemberAccessor(get_expr(), field_name);
+        };
+        Zero(field->Type(), field_expr);
+      }
+      return;
+    }
+
+    if (auto* array_ty = ty->As<sem::Array>()) {  // element by element
+      // TODO(bclayton): If array sizes become pipeline-overridable then this
+      // will need to emit code for a loop.
+      // See https://github.com/gpuweb/gpuweb/pull/1792
+      for (size_t n = 0; n < array_ty->Count(); ++n) {
+        auto element_expr = [&] {
+          return ctx.dst->IndexAccessor(get_expr(),
+                                        static_cast<ProgramBuilder::u32>(n));
+        };
+        Zero(array_ty->ElemType(), element_expr);
+      }
+      return;
+    }
+
+    TINT_UNREACHABLE(ctx.dst->Diagnostics())
+        << "could not zero workgroup type: " << ty->type_name();
+  }
+
+  /// @returns true when `ty` contains no atomic type, so that it can be zeroed
+  /// with a single type constructor without operands. A false result means
+  /// the initialization must be decomposed into multiple sub-initializations.
+  /// @param ty the type to inspect
+  static bool CanZero(const sem::Type* ty) {
+    if (ty->Is<sem::Atomic>()) {
+      return false;
+    }
+    if (auto* array_ty = ty->As<sem::Array>()) {
+      return CanZero(array_ty->ElemType());
+    }
+    if (auto* struct_ty = ty->As<sem::Struct>()) {
+      for (auto* field : struct_ty->Members()) {
+        if (!CanZero(field->Type())) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+};
+
+ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;  // no own state
+
+ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;  // trivial
+
+Output ZeroInitWorkgroupMemory::Run(const Program* in, const DataMap&) {
+  ProgramBuilder out;  // builder for the transformed program
+  CloneContext ctx(&out, in);
+
+  auto& sem = ctx.src->Sem();
+
+  for (auto* ast_func : in->AST().Functions()) {
+    if (!ast_func->IsEntryPoint()) {
+      continue;  // only entry points receive the initialization
+    }
+
+    // Build the statements that zero-initialize every workgroup storage class
+    // variable referenced by this entry point.
+    ast::StatementList stmts;
+    auto* func = sem.Get(ast_func);
+    for (auto* var : func->ReferencedModuleVariables()) {
+      if (var->StorageClass() != ast::StorageClass::kWorkgroup) {
+        continue;
+      }
+      State{ctx, stmts}.Zero(var->Type()->UnwrapRef(), [&] {
+        auto var_name = ctx.Clone(var->Declaration()->symbol());
+        return ctx.dst->Expr(var_name);
+      });
+    }
+
+    if (stmts.empty()) {
+      continue;  // No workgroup variables to initialize.
+    }
+
+    // Look for an existing local_invocation_index builtin, either as a direct
+    // parameter or as a member of a structure-typed parameter.
+    ast::Expression* local_index = nullptr;
+    for (auto* param : ast_func->params()) {
+      if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
+              param->decorations())) {
+        if (builtin->value() == ast::Builtin::kLocalInvocationIndex) {
+          local_index = ctx.dst->Expr(ctx.Clone(param->symbol()));
+          break;  // stop scanning parameters
+        }
+      }
+
+      if (auto* str = sem.Get(param)->Type()->As<sem::Struct>()) {
+        for (auto* member : str->Members()) {
+          if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
+                  member->Declaration()->decorations())) {
+            if (builtin->value() == ast::Builtin::kLocalInvocationIndex) {
+              auto* param_expr = ctx.dst->Expr(ctx.Clone(param->symbol()));
+              auto member_name = ctx.Clone(member->Declaration()->symbol());
+              local_index = ctx.dst->MemberAccessor(param_expr, member_name);
+              break;  // leaves the member loop only
+            }
+          }
+        }
+      }
+    }
+    if (!local_index) {
+      // None found: append a uniquely named local_invocation_index parameter.
+      auto* param = ctx.dst->Param(
+          ctx.dst->Symbols().New("local_invocation_index"), ctx.dst->ty.u32(),
+          {ctx.dst->Builtin(ast::Builtin::kLocalInvocationIndex)});
+      ctx.InsertBack(ast_func->params(), param);
+      local_index = ctx.dst->Expr(param->symbol());
+    }
+
+    // Only the invocation whose local index is 0 zero-initializes the memory:
+    // wrap stmts in an if-statement testing `local_index == 0u`.
+    // TODO(crbug.com/tint/910): We should attempt to optimize this for arrays.
+    auto* if_zero_local_index = ctx.dst->create<ast::BinaryExpression>(
+        ast::BinaryOp::kEqual, local_index, ctx.dst->Expr(0u));
+    auto* if_stmt = ctx.dst->If(if_zero_local_index, ctx.dst->Block(stmts));
+
+    // Insert this if-statement at the top of the entry point.
+    ctx.InsertFront(ast_func->body()->statements(), if_stmt);
+
+    // Follow the if-statement with a single workgroupBarrier() call.
+    ctx.InsertFront(
+        ast_func->body()->statements(),
+        ctx.dst->create<ast::CallStatement>(ctx.dst->Call("workgroupBarrier")));
+  }
+
+  ctx.Clone();  // clone the source program into `out`
+
+  return Output(Program(std::move(out)));
+}
+
+}  // namespace transform
+}  // namespace tint
diff --git a/src/transform/zero_init_workgroup_memory.h b/src/transform/zero_init_workgroup_memory.h
new file mode 100644
index 0000000..bf846c7
--- /dev/null
+++ b/src/transform/zero_init_workgroup_memory.h
@@ -0,0 +1,47 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+#define SRC_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+
+#include "src/transform/transform.h"
+
+namespace tint {
+namespace transform {
+
+/// ZeroInitWorkgroupMemory is a transform that prepends to each entry point a
+/// zero-initialization of the workgroup variables it uses (directly or via
+/// called functions), run by local invocation 0 and followed by a barrier.
+class ZeroInitWorkgroupMemory : public Transform {
+ public:
+  /// Constructor
+  ZeroInitWorkgroupMemory();
+
+  /// Destructor
+  ~ZeroInitWorkgroupMemory() override;
+
+  /// Runs the transform on `program`, returning the transformation result.
+  /// @param program the source program to transform
+  /// @param data extra transform-specific input data; unused by this transform
+  /// @returns the transformation result
+  Output Run(const Program* program, const DataMap& data = {}) override;
+
+ private:
+  struct State;  // implementation detail, defined in the .cc file
+};
+
+}  // namespace transform
+}  // namespace tint
+
+#endif  // SRC_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
diff --git a/src/transform/zero_init_workgroup_memory_test.cc b/src/transform/zero_init_workgroup_memory_test.cc
new file mode 100644
index 0000000..ab1b305
--- /dev/null
+++ b/src/transform/zero_init_workgroup_memory_test.cc
@@ -0,0 +1,563 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/transform/zero_init_workgroup_memory.h"
+
+#include <utility>
+
+#include "src/transform/test_helper.h"
+
+namespace tint {
+namespace transform {
+namespace {
+
+using ZeroInitWorkgroupMemoryTest = TransformTest;
+
+TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
+  auto* src = "";
+  auto* expect = src;  // an empty module must stay empty
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, NoWorkgroupVars) {
+  auto* src = R"(
+var<private> v : i32;
+
+fn f() {
+  v = 1;
+}
+)";
+  auto* expect = src;  // no workgroup variables: nothing to inject
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars) {
+  auto* src = R"(
+var<workgroup> a : i32;
+
+var<workgroup> b : i32;
+
+var<workgroup> c : i32;
+
+fn unreferenced() {
+  b = c;
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+}
+)";
+  auto* expect = src;  // entry point f references none of a, b, c
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex) {
+  auto* src = R"(
+var<workgroup> v : i32;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] idx : u32) {
+  ignore(v); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+var<workgroup> v : i32;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] idx : u32) {
+  if ((idx == 0u)) {
+    v = i32();
+  }
+  workgroupBarrier();
+  ignore(v);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // the existing `idx` parameter is reused
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest,
+       SingleWorkgroupVar_ExistingLocalIndexInStruct) {
+  auto* src = R"(
+var<workgroup> v : i32;
+
+struct Params {
+  [[builtin(local_invocation_index)]] idx : u32;
+};
+
+[[stage(compute), workgroup_size(1)]]
+fn f(params : Params) {
+  ignore(v); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+var<workgroup> v : i32;
+
+struct Params {
+  [[builtin(local_invocation_index)]]
+  idx : u32;
+};
+
+[[stage(compute), workgroup_size(1)]]
+fn f(params : Params) {
+  if ((params.idx == 0u)) {
+    v = i32();
+  }
+  workgroupBarrier();
+  ignore(v);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // builtin struct member used as params.idx
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex) {
+  auto* src = R"(
+var<workgroup> v : i32;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+  ignore(v); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+var<workgroup> v : i32;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    v = i32();
+  }
+  workgroupBarrier();
+  ignore(v);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // a local_invocation_index param is added
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_ExistingLocalIndex) {
+  auto* src = R"(
+struct S {
+  x : i32;
+  y : array<i32, 8>;
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] idx : u32) {
+  ignore(a); // Initialization should be inserted above this statement
+  ignore(b);
+  ignore(c);
+}
+)";
+  auto* expect = R"(
+struct S {
+  x : i32;
+  y : array<i32, 8>;
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] idx : u32) {
+  if ((idx == 0u)) {
+    a = i32();
+    b = S();
+    c = array<S, 32>();
+  }
+  workgroupBarrier();
+  ignore(a);
+  ignore(b);
+  ignore(c);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // atomic-free types: one assignment each
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex) {
+  auto* src = R"(
+struct S {
+  x : i32;
+  y : array<i32, 8>;
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>) {
+  ignore(a); // Initialization should be inserted above this statement
+  ignore(b);
+  ignore(c);
+}
+)";
+  auto* expect = R"(
+struct S {
+  x : i32;
+  y : array<i32, 8>;
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    a = i32();
+    b = S();
+    c = array<S, 32>();
+  }
+  workgroupBarrier();
+  ignore(a);
+  ignore(b);
+  ignore(c);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // new param follows the existing one
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints) {
+  auto* src = R"(
+struct S {
+  x : i32;
+  y : array<i32, 8>;
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f1() {
+  ignore(a); // Initialization should be inserted above this statement
+  ignore(c);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>) {
+  ignore(b); // Initialization should be inserted above this statement
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f3() {
+  ignore(c); // Initialization should be inserted above this statement
+  ignore(a);
+}
+)";
+  auto* expect = R"(
+struct S {
+  x : i32;
+  y : array<i32, 8>;
+};
+
+var<workgroup> a : i32;
+
+var<workgroup> b : S;
+
+var<workgroup> c : array<S, 32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f1([[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    a = i32();
+    c = array<S, 32>();
+  }
+  workgroupBarrier();
+  ignore(a);
+  ignore(c);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index_1 : u32) {
+  if ((local_invocation_index_1 == 0u)) {
+    b = S();
+  }
+  workgroupBarrier();
+  ignore(b);
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f3([[builtin(local_invocation_index)]] local_invocation_index_2 : u32) {
+  if ((local_invocation_index_2 == 0u)) {
+    c = array<S, 32>();
+    a = i32();
+  }
+  workgroupBarrier();
+  ignore(c);
+  ignore(a);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // each entry point zeroes only what it uses
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage) {
+  auto* src = R"(
+var<workgroup> v : i32;
+
+fn use_v() {
+  ignore(v);
+}
+
+fn call_use_v() {
+  use_v();
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] idx : u32) {
+  call_use_v(); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+var<workgroup> v : i32;
+
+fn use_v() {
+  ignore(v);
+}
+
+fn call_use_v() {
+  use_v();
+}
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] idx : u32) {
+  if ((idx == 0u)) {
+    v = i32();
+  }
+  workgroupBarrier();
+  call_use_v();
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // v is only reached through called functions
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics) {
+  auto* src = R"(
+var<workgroup> i : atomic<i32>;
+var<workgroup> u : atomic<u32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+  ignore(i); // Initialization should be inserted above this statement
+  ignore(u);
+}
+)";
+  auto* expect = R"(
+var<workgroup> i : atomic<i32>;
+
+var<workgroup> u : atomic<u32>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    atomicStore(&(i), i32());
+    atomicStore(&(u), u32());
+  }
+  workgroupBarrier();
+  ignore(i);
+  ignore(u);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // atomics are zeroed with atomicStore()
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics) {
+  auto* src = R"(
+struct S {
+  a : i32;
+  i : atomic<i32>;
+  b : f32;
+  u : atomic<u32>;
+  c : u32;
+};
+
+var<workgroup> w : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+  ignore(w); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+struct S {
+  a : i32;
+  i : atomic<i32>;
+  b : f32;
+  u : atomic<u32>;
+  c : u32;
+};
+
+var<workgroup> w : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    w.a = i32();
+    atomicStore(&(w.i), i32());
+    w.b = f32();
+    atomicStore(&(w.u), u32());
+    w.c = u32();
+  }
+  workgroupBarrier();
+  ignore(w);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // struct with atomics: zeroed per member
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics) {
+  auto* src = R"(
+var<workgroup> w : array<atomic<u32>, 4>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+  ignore(w); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+var<workgroup> w : array<atomic<u32>, 4>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    atomicStore(&(w[0u]), u32());
+    atomicStore(&(w[1u]), u32());
+    atomicStore(&(w[2u]), u32());
+    atomicStore(&(w[3u]), u32());
+  }
+  workgroupBarrier();
+  ignore(w);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // array of atomics: zeroed per element
+}
+
+TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics) {
+  auto* src = R"(
+struct S {
+  a : i32;
+  i : atomic<i32>;
+  b : f32;
+  u : atomic<u32>;
+  c : u32;
+};
+
+var<workgroup> w : array<S, 4>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f() {
+  ignore(w); // Initialization should be inserted above this statement
+}
+)";
+  auto* expect = R"(
+struct S {
+  a : i32;
+  i : atomic<i32>;
+  b : f32;
+  u : atomic<u32>;
+  c : u32;
+};
+
+var<workgroup> w : array<S, 4>;
+
+[[stage(compute), workgroup_size(1)]]
+fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
+  if ((local_invocation_index == 0u)) {
+    w[0u].a = i32();
+    atomicStore(&(w[0u].i), i32());
+    w[0u].b = f32();
+    atomicStore(&(w[0u].u), u32());
+    w[0u].c = u32();
+    w[1u].a = i32();
+    atomicStore(&(w[1u].i), i32());
+    w[1u].b = f32();
+    atomicStore(&(w[1u].u), u32());
+    w[1u].c = u32();
+    w[2u].a = i32();
+    atomicStore(&(w[2u].i), i32());
+    w[2u].b = f32();
+    atomicStore(&(w[2u].u), u32());
+    w[2u].c = u32();
+    w[3u].a = i32();
+    atomicStore(&(w[3u].i), i32());
+    w[3u].b = f32();
+    atomicStore(&(w[3u].u), u32());
+    w[3u].c = u32();
+  }
+  workgroupBarrier();
+  ignore(w);
+}
+)";
+
+  auto got = Run<ZeroInitWorkgroupMemory>(src);
+
+  EXPECT_EQ(expect, str(got));  // zeroed per element, then per member
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index d4df320..04d781d 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -292,6 +292,7 @@
     "../src/transform/transform_test.cc",
     "../src/transform/vertex_pulling_test.cc",
     "../src/transform/wrap_arrays_in_structs_test.cc",
+    "../src/transform/zero_init_workgroup_memory_test.cc",
     "../src/utils/enum_set_test.cc",
     "../src/utils/get_or_create_test.cc",
     "../src/utils/hash_test.cc",