[ir] Add ZeroInitWorkgroupMemory transform

Use it in the SPIR-V writer.

Bug: tint:1718, tint:1906

Change-Id: Ie1fe7491518aa0d793f45eda52f6a82bcd5485e3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/152463
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/transform/BUILD.bazel b/src/tint/lang/core/ir/transform/BUILD.bazel
index 7216985..4388198 100644
--- a/src/tint/lang/core/ir/transform/BUILD.bazel
+++ b/src/tint/lang/core/ir/transform/BUILD.bazel
@@ -37,6 +37,7 @@
     "robustness.cc",
     "shader_io.cc",
     "std140.cc",
+    "zero_init_workgroup_memory.cc",
   ],
   hdrs = [
     "add_empty_entry_point.h",
@@ -50,6 +51,7 @@
     "robustness.h",
     "shader_io.h",
     "std140.h",
+    "zero_init_workgroup_memory.h",
   ],
   deps = [
     "//src/tint/api/common",
@@ -92,6 +94,7 @@
     "multiplanar_external_texture_test.cc",
     "robustness_test.cc",
     "std140_test.cc",
+    "zero_init_workgroup_memory_test.cc",
   ],
   deps = [
     "//src/tint/api/common",
diff --git a/src/tint/lang/core/ir/transform/BUILD.cmake b/src/tint/lang/core/ir/transform/BUILD.cmake
index f474b48..5e772df 100644
--- a/src/tint/lang/core/ir/transform/BUILD.cmake
+++ b/src/tint/lang/core/ir/transform/BUILD.cmake
@@ -48,6 +48,8 @@
   lang/core/ir/transform/shader_io.h
   lang/core/ir/transform/std140.cc
   lang/core/ir/transform/std140.h
+  lang/core/ir/transform/zero_init_workgroup_memory.cc
+  lang/core/ir/transform/zero_init_workgroup_memory.h
 )
 
 tint_target_add_dependencies(tint_lang_core_ir_transform lib
@@ -90,6 +92,7 @@
   lang/core/ir/transform/multiplanar_external_texture_test.cc
   lang/core/ir/transform/robustness_test.cc
   lang/core/ir/transform/std140_test.cc
+  lang/core/ir/transform/zero_init_workgroup_memory_test.cc
 )
 
 tint_target_add_dependencies(tint_lang_core_ir_transform_test test
diff --git a/src/tint/lang/core/ir/transform/BUILD.gn b/src/tint/lang/core/ir/transform/BUILD.gn
index 82e462b..cf99aba 100644
--- a/src/tint/lang/core/ir/transform/BUILD.gn
+++ b/src/tint/lang/core/ir/transform/BUILD.gn
@@ -53,6 +53,8 @@
     "shader_io.h",
     "std140.cc",
     "std140.h",
+    "zero_init_workgroup_memory.cc",
+    "zero_init_workgroup_memory.h",
   ]
   deps = [
     "${tint_src_dir}/api/common",
@@ -93,6 +95,7 @@
       "multiplanar_external_texture_test.cc",
       "robustness_test.cc",
       "std140_test.cc",
+      "zero_init_workgroup_memory_test.cc",
     ]
     deps = [
       "${tint_src_dir}:gmock_and_gtest",
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
new file mode 100644
index 0000000..cd6c589
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
@@ -0,0 +1,403 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
+
+#include <map>
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/utils/containers/reverse.h"
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+namespace tint::core::ir::transform {
+
+namespace {
+
+/// PIMPL state for the transform.
+struct State {
+    /// The IR module.
+    Module& ir;
+
+    /// The IR builder.
+    Builder b{ir};
+
+    /// The type manager.
+    core::type::Manager& ty{ir.Types()};
+
+    /// VarSet is a hash set of workgroup variables.
+    using VarSet = Hashset<Var*, 8>;
+
+    /// A map from variable to an ID used for sorting.
+    Hashmap<Var*, uint32_t, 8> var_to_id{};
+
+    /// A map from blocks to their directly referenced workgroup variables.
+    Hashmap<Block*, VarSet, 64> block_to_direct_vars{};
+
+    /// A map from functions to their transitively referenced workgroup variables.
+    Hashmap<Function*, VarSet, 8> function_to_transitive_vars{};
+
+    /// ArrayIndex represents a required array index for an access instruction.
+    struct ArrayIndex {
+        /// The size of the array that will be indexed.
+        uint32_t count = 0u;
+    };
+
+    /// Index represents an index for an access instruction, which is either a constant value or
+    /// an array index that will be dynamically calculated from an array size.
+    using Index = std::variant<uint32_t, ArrayIndex>;
+
+    /// Store describes a store to a sub-element of a workgroup variable.
+    struct Store {
+        /// The workgroup variable.
+        Var* var = nullptr;
+        /// The store type of the element.
+        const type::Type* store_type = nullptr;
+        /// The list of index operands to get to the element.
+        Vector<Index, 4> indices;
+    };
+
+    /// StoreList is a list of `Store` descriptors.
+    using StoreList = Vector<Store, 8>;
+
+    /// StoreMap is a map from iteration count to a list of `Store` descriptors.
+    using StoreMap = Hashmap<uint32_t, StoreList, 8>;
+
+    /// Process the module.
+    void Process() {
+        if (!ir.root_block) {
+            return;
+        }
+
+        // Loop over module-scope variables, looking for workgroup variables.
+        uint32_t next_id = 0;
+        for (auto inst : *ir.root_block) {
+            if (auto* var = inst->As<Var>()) {
+                auto* ptr = var->Result()->Type()->As<core::type::Pointer>();
+                if (ptr && ptr->AddressSpace() == core::AddressSpace::kWorkgroup) {
+                    // Record the usage of the variable for each block that references it.
+                    var->Result()->ForEachUse([&](const Usage& use) {
+                        block_to_direct_vars.GetOrZero(use.instruction->Block())->Add(var);
+                    });
+                    var_to_id.Add(var, next_id++);
+                }
+            }
+        }
+
+        // Process each entry point function.
+        for (auto* func : ir.functions) {
+            if (func->Stage() == Function::PipelineStage::kCompute) {
+                ProcessEntryPoint(func);
+            }
+        }
+    }
+
+    /// Process an entry point function to zero-initialize the workgroup variables that it uses.
+    /// @param func the entry point function
+    void ProcessEntryPoint(Function* func) {
+        // Get list of transitively referenced workgroup variables.
+        auto vars = GetReferencedVars(func);
+        if (vars.IsEmpty()) {
+            return;
+        }
+
+        // Sort the variables to get deterministic output in tests.
+        auto sorted_vars = vars.Vector();
+        sorted_vars.Sort([&](Var* first, Var* second) {
+            return *var_to_id.Get(first) < *var_to_id.Get(second);
+        });
+
+        // Build list of store descriptors for all workgroup variables.
+        StoreMap stores;
+        for (auto* var : sorted_vars) {
+            PrepareStores(var, var->Result()->Type()->UnwrapPtr(), 1, {}, stores);
+        }
+
+        // Sort the iteration counts to get deterministic output in tests.
+        auto sorted_iteration_counts = stores.Keys();
+        sorted_iteration_counts.Sort();
+
+        // Capture the first instruction of the function.
+        // All new instructions will be inserted before this.
+        auto* function_start = func->Block()->Front();
+
+        // Get the local invocation index and the linearized workgroup size.
+        auto* local_index = GetLocalInvocationIndex(func);
+        auto wgsizes = func->WorkgroupSize().value();
+        auto wgsize = wgsizes[0] * wgsizes[1] * wgsizes[2];
+
+        // Insert instructions to zero-initialize every variable.
+        b.InsertBefore(function_start, [&] {
+            for (auto count : sorted_iteration_counts) {
+                auto element_stores = stores.Get(count);
+                if (count == 1u) {
+                    // Make the first invocation in the group perform all of the non-arrayed stores.
+                    auto* ifelse = b.If(b.Equal(ty.bool_(), local_index, 0_u));
+                    b.Append(ifelse->True(), [&] {
+                        for (auto& store : *element_stores) {
+                            GenerateStore(store, count, b.Constant(0_u));
+                        }
+                        b.ExitIf(ifelse);
+                    });
+                } else {
+                    // Use a loop for arrayed stores.
+                    GenerateZeroingLoop(local_index, count, wgsize, *element_stores);
+                }
+            }
+            b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+        });
+    }
+
+    /// Get the set of workgroup variables transitively referenced by @p func.
+    /// @param func the function
+    /// @returns the set of transitively referenced workgroup variables
+    VarSet GetReferencedVars(Function* func) {
+        return function_to_transitive_vars.GetOrCreate(func, [&] {
+            VarSet vars;
+            GetReferencedVars(func->Block(), vars);
+            return vars;
+        });
+    }
+
+    /// Get the set of workgroup variables transitively referenced by @p block.
+    /// @param block the block
+    /// @param vars the set of transitively referenced workgroup variables to populate
+    void GetReferencedVars(Block* block, VarSet& vars) {
+        // Add directly referenced vars.
+        if (auto itr = block_to_direct_vars.Find(block)) {
+            for (auto* var : *itr) {
+                vars.Add(var);
+            }
+        }
+
+        // Loop over instructions in the block.
+        for (auto* inst : *block) {
+            tint::Switch(
+                inst,
+                [&](UserCall* call) {
+                    // Get variables referenced by a function called from this block.
+                    auto callee_vars = GetReferencedVars(call->Target());
+                    for (auto* var : callee_vars) {
+                        vars.Add(var);
+                    }
+                },
+                [&](ControlInstruction* ctrl) {
+                    // Recurse into control instructions and gather their referenced vars.
+                    ctrl->ForeachBlock([&](Block* blk) { GetReferencedVars(blk, vars); });
+                });
+        }
+    }
+
+    /// Recursively generate store descriptors for a workgroup variable.
+    /// Determines the combined array iteration count of each inner element.
+    /// @param var the workgroup variable
+    /// @param type the current element type
+    /// @param iteration_count the iteration count of this inner element of the variable
+    /// @param indices the access indices needed to get to this element
+    /// @param stores the map of stores to populate
+    void PrepareStores(Var* var,
+                       const type::Type* type,
+                       uint32_t iteration_count,
+                       Vector<Index, 4> indices,
+                       StoreMap& stores) {
+        // If this type can be trivially zeroed, store to the whole element.
+        if (CanTriviallyZero(type)) {
+            stores.GetOrZero(iteration_count)->Push(Store{var, type, indices});
+            return;
+        }
+
+        tint::Switch(
+            type,
+            [&](const type::Array* arr) {
+                // Add an array index to the list and recurse into the element type.
+                TINT_ASSERT(arr->ConstantCount());
+                auto count = arr->ConstantCount().value();
+                auto new_indices = indices;
+                if (count > 1) {
+                    new_indices.Push(ArrayIndex{count});
+                } else {
+                    new_indices.Push(0u);
+                }
+                PrepareStores(var, arr->ElemType(), iteration_count * count, new_indices, stores);
+            },
+            [&](const type::Atomic*) {
+                stores.GetOrZero(iteration_count)->Push(Store{var, type, indices});
+            },
+            [&](const type::Struct* str) {
+                for (auto* member : str->Members()) {
+                    // Add the member index to the index list and recurse into its type.
+                    auto new_indices = indices;
+                    new_indices.Push(member->Index());
+                    PrepareStores(var, member->Type(), iteration_count, new_indices, stores);
+                }
+            },
+            [&](Default) { TINT_UNREACHABLE(); });
+    }
+
+    /// Get or inject an entry point builtin for the local invocation index.
+    /// @param func the entry point function
+    /// @returns the local invocation index builtin
+    Value* GetLocalInvocationIndex(Function* func) {
+        // Look for an existing local_invocation_index builtin parameter.
+        for (auto* param : func->Params()) {
+            if (auto* str = param->Type()->As<type::Struct>()) {
+                // Check each member for the local invocation index builtin attribute.
+                for (auto* member : str->Members()) {
+                    if (member->Attributes().builtin && member->Attributes().builtin.value() ==
+                                                            BuiltinValue::kLocalInvocationIndex) {
+                        auto* access = b.Access(ty.u32(), param, u32(member->Index()));
+                        access->InsertBefore(func->Block()->Front());
+                        return access->Result();
+                    }
+                }
+            } else {
+                // Check if the parameter is the local invocation index.
+                if (param->Builtin() &&
+                    param->Builtin().value() == FunctionParam::Builtin::kLocalInvocationIndex) {
+                    return param;
+                }
+            }
+        }
+
+        // No local invocation index was found, so add one to the parameter list and use that.
+        Vector<FunctionParam*, 4> params = func->Params();
+        auto* param = b.FunctionParam("tint_local_index", ty.u32());
+        param->SetBuiltin(FunctionParam::Builtin::kLocalInvocationIndex);
+        params.Push(param);
+        func->SetParams(params);
+        return param;
+    }
+
+    /// Generate the store instruction for a given store descriptor.
+    /// @param store the store descriptor
+    /// @param total_count the total number of elements that will be zeroed
+    /// @param linear_index the linear index of the single element that will be zeroed
+    void GenerateStore(const Store& store, uint32_t total_count, Value* linear_index) {
+        auto* to = store.var->Result();
+        if (!store.indices.IsEmpty()) {
+            // Build the access indices to get to the target element.
+            // We walk backwards along the index list so that adjacent invocation store to
+            // adjacent array elements.
+            uint32_t count = 1;
+            Vector<Value*, 4> indices;
+            for (auto idx : Reverse(store.indices)) {
+                if (std::holds_alternative<ArrayIndex>(idx)) {
+                    // Array indices are computed from the linear index based on the size of the
+                    // array and the size of the sub-arrays that have already been indexed.
+                    auto array_index = std::get<ArrayIndex>(idx);
+                    Value* index = linear_index;
+                    if (count > 1) {
+                        index = b.Divide(ty.u32(), index, u32(count))->Result();
+                    }
+                    if (total_count > count * array_index.count) {
+                        index = b.Modulo(ty.u32(), index, u32(array_index.count))->Result();
+                    }
+                    indices.Push(index);
+                    count *= array_index.count;
+                } else {
+                    // Constant indices are added to the list unmodified.
+                    indices.Push(b.Constant(u32(std::get<uint32_t>(idx))));
+                }
+            }
+            indices.Reverse();
+            to = b.Access(ty.ptr(workgroup, store.store_type), to, indices)->Result();
+        }
+
+        // Generate the store instruction.
+        if (auto* atomic = store.store_type->As<type::Atomic>()) {
+            auto* zero = b.Constant(ir.constant_values.Zero(atomic->Type()));
+            b.Call(ty.void_(), core::BuiltinFn::kAtomicStore, to, zero);
+        } else {
+            auto* zero = b.Constant(ir.constant_values.Zero(store.store_type));
+            b.Store(to, zero);
+        }
+    }
+
+    /// Generate a loop for a list of stores with the same iteration count.
+    /// @param local_index the local invocation index
+    /// @param total_count the number of iterations needed to store to all elements
+    /// @param wgsize the linearized workgroup size
+    /// @param stores the list of store descriptors
+    void GenerateZeroingLoop(Value* local_index,
+                             uint32_t total_count,
+                             uint32_t wgsize,
+                             const StoreList& stores) {
+        // The loop is equivalent to:
+        //   for (var idx = local_index; idx < linear_iteration_count; idx += wgsize) {
+        //     <store to elements at `idx`>
+        //   }
+        auto* loop = b.Loop();
+        auto* index = b.BlockParam(ty.u32());
+        loop->Body()->SetParams({index});
+        b.Append(loop->Initializer(), [&] {  //
+            b.NextIteration(loop, local_index);
+        });
+        b.Append(loop->Body(), [&] {
+            // Exit the loop when the iteration count has been exceeded.
+            auto* gt_max = b.GreaterThan(ty.bool_(), index, u32(total_count - 1u));
+            auto* ifelse = b.If(gt_max);
+            b.Append(ifelse->True(), [&] {  //
+                b.ExitLoop(loop);
+            });
+
+            // Insert all of the store instructions.
+            for (auto& store : stores) {
+                GenerateStore(store, total_count, index);
+            }
+
+            b.Continue(loop);
+        });
+        b.Append(loop->Continuing(), [&] {  //
+            // Increment the loop index by linearized workgroup size.
+            b.NextIteration(loop, b.Add(ty.u32(), index, u32(wgsize)));
+        });
+    }
+
+    /// Check if a type can be efficiently zeroed with a single store. Returns `false` if there are
+    /// any nested arrays or atomics.
+    /// @param type the type to inspect
+    /// @returns true if a variable with store type @p ty can be efficiently zeroed
+    bool CanTriviallyZero(const core::type::Type* type) {
+        if (type->IsAnyOf<core::type::Atomic, core::type::Array>()) {
+            return false;
+        }
+        if (auto* str = type->As<core::type::Struct>()) {
+            for (auto* member : str->Members()) {
+                if (!CanTriviallyZero(member->Type())) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+};
+
+}  // namespace
+
+Result<SuccessType> ZeroInitWorkgroupMemory(Module& ir) {
+    auto result = ValidateAndDumpIfNeeded(ir, "ZeroInitWorkgroupMemory transform");
+    if (!result) {
+        return result;
+    }
+
+    State{ir}.Process();
+
+    return Success;
+}
+
+}  // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h
new file mode 100644
index 0000000..02306c5
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h
@@ -0,0 +1,37 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+#define SRC_TINT_LANG_CORE_IR_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+
+#include <string>
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::core::ir::transform {
+
+/// ZeroInitWorkgroupMemory is a transform that injects code at the top of each entry point to
+/// zero-initialize workgroup memory used by that entry point.
+/// @param module the module to transform
+/// @returns success or failure
+Result<SuccessType> ZeroInitWorkgroupMemory(Module& module);
+
+}  // namespace tint::core::ir::transform
+
+#endif  // SRC_TINT_LANG_CORE_IR_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory_test.cc b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory_test.cc
new file mode 100644
index 0000000..d2c556b
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory_test.cc
@@ -0,0 +1,1940 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+
+namespace tint::core::ir::transform {
+namespace {
+
+using namespace tint::core::fluent_types;     // NOLINT
+using namespace tint::core::number_suffixes;  // NOLINT
+
+class IR_ZeroInitWorkgroupMemoryTest : public TransformTest {
+  protected:
+    Function* MakeEntryPoint(const char* name,
+                             uint32_t wgsize_x,
+                             uint32_t wgsize_y,
+                             uint32_t wgsize_z) {
+        auto* func = b.Function(name, ty.void_(), Function::PipelineStage::kCompute);
+        func->SetWorkgroupSize(wgsize_x, wgsize_y, wgsize_z);
+        return func;
+    }
+
+    Var* MakeVar(const char* name, const type::Type* store_type) {
+        auto* var = b.Var(name, ty.ptr(workgroup, store_type));
+        b.RootBlock()->Append(var);
+        return var;
+    }
+};
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NoRootBlock) {
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Return(func);
+    });
+
+    auto* expect = R"(
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+  %b1 = block {
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, WorkgroupVarUnused) {
+    MakeVar("wgvar", ty.i32());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, i32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = src;
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarBool) {
+    auto* var = MakeVar("wgvar", ty.bool_());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:bool = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, false
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:bool = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarI32) {
+    auto* var = MakeVar("wgvar", ty.i32());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, i32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:i32 = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, i32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, 0i
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:i32 = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarU32) {
+    auto* var = MakeVar("wgvar", ty.u32());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, u32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:u32 = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, u32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, 0u
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:u32 = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarF32) {
+    auto* var = MakeVar("wgvar", ty.f32());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, f32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:f32 = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, f32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, 0.0f
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:f32 = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarF16) {
+    auto* var = MakeVar("wgvar", ty.f16());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, f16, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:f16 = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, f16, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, 0.0h
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:f16 = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, AtomicI32) {
+    auto* var = MakeVar("wgvar", ty.atomic<i32>());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Call(ty.i32(), core::BuiltinFn::kAtomicLoad, var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, atomic<i32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:i32 = atomicLoad %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, atomic<i32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        %5:void = atomicStore %wgvar, 0i
+        exit_if  # if_1
+      }
+    }
+    %6:void = workgroupBarrier
+    %7:i32 = atomicLoad %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, AtomicU32) {
+    auto* var = MakeVar("wgvar", ty.atomic<u32>());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Call(ty.u32(), core::BuiltinFn::kAtomicLoad, var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, atomic<u32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:u32 = atomicLoad %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, atomic<u32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        %5:void = atomicStore %wgvar, 0u
+        exit_if  # if_1
+      }
+    }
+    %6:void = workgroupBarrier
+    %7:u32 = atomicLoad %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfI32) {
+    auto* var = MakeVar("wgvar", ty.array<i32, 4>());
+
+    auto* func = MakeEntryPoint("main", 11, 2, 3);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<i32, 4>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+  %b2 = block {
+    %3:array<i32, 4> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<i32, 4>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 3u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:ptr<workgroup, i32, read_write> = access %wgvar, %4:u32
+        store %6, 0i
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %7:u32 = add %4:u32, 66u
+        next_iteration %b4 %7
+      }
+    }
+    %8:void = workgroupBarrier
+    %9:array<i32, 4> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfArrayOfU32) {
+    auto* var = MakeVar("wgvar", ty.array(ty.array<u32, 5>(), 7));
+
+    auto* func = MakeEntryPoint("main", 11, 2, 3);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+  %b2 = block {
+    %3:array<array<u32, 5>, 7> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 34u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:u32 = mod %4:u32, 5u
+        %7:u32 = div %4:u32, 5u
+        %8:ptr<workgroup, u32, read_write> = access %wgvar, %7, %6
+        store %8, 0u
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %9:u32 = add %4:u32, 66u
+        next_iteration %b4 %9
+      }
+    }
+    %10:void = workgroupBarrier
+    %11:array<array<u32, 5>, 7> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfArrayOfArray) {
+    auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 7>(), 5), 3));
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 7>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:array<array<array<i32, 7>, 5>, 3> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 7>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 104u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:u32 = mod %4:u32, 7u
+        %7:u32 = div %4:u32, 7u
+        %8:u32 = mod %7, 5u
+        %9:u32 = div %4:u32, 35u
+        %10:ptr<workgroup, i32, read_write> = access %wgvar, %9, %8, %6
+        store %10, 0i
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %11:u32 = add %4:u32, 1u
+        next_iteration %b4 %11
+      }
+    }
+    %12:void = workgroupBarrier
+    %13:array<array<array<i32, 7>, 5>, 3> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayInnerSizeOne) {
+    auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 1>(), 5), 3));
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 1>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:array<array<array<i32, 1>, 5>, 3> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 1>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 14u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:u32 = mod %4:u32, 5u
+        %7:u32 = div %4:u32, 5u
+        %8:ptr<workgroup, i32, read_write> = access %wgvar, %7, %6, 0u
+        store %8, 0i
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %9:u32 = add %4:u32, 1u
+        next_iteration %b4 %9
+      }
+    }
+    %10:void = workgroupBarrier
+    %11:array<array<array<i32, 1>, 5>, 3> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayMiddleSizeOne) {
+    auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 3>(), 1), 5));
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 3>, 1>, 5>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:array<array<array<i32, 3>, 1>, 5> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 3>, 1>, 5>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 14u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:u32 = mod %4:u32, 3u
+        %7:u32 = div %4:u32, 3u
+        %8:ptr<workgroup, i32, read_write> = access %wgvar, %7, 0u, %6
+        store %8, 0i
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %9:u32 = add %4:u32, 1u
+        next_iteration %b4 %9
+      }
+    }
+    %10:void = workgroupBarrier
+    %11:array<array<array<i32, 3>, 1>, 5> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayOuterSizeOne) {
+    auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 3>(), 5), 1));
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 3>, 5>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:array<array<array<i32, 3>, 5>, 1> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<array<i32, 3>, 5>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 14u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:u32 = mod %4:u32, 3u
+        %7:u32 = div %4:u32, 3u
+        %8:ptr<workgroup, i32, read_write> = access %wgvar, 0u, %7, %6
+        store %8, 0i
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %9:u32 = add %4:u32, 1u
+        next_iteration %b4 %9
+      }
+    }
+    %10:void = workgroupBarrier
+    %11:array<array<array<i32, 3>, 5>, 1> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayTotalSizeOne) {
+    auto* var = MakeVar("wgvar", ty.array(ty.array<i32, 1>(), 1));
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<i32, 1>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:array<array<i32, 1>, 1> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<array<i32, 1>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        %5:ptr<workgroup, i32, read_write> = access %wgvar, 0u, 0u
+        store %5, 0i
+        exit_if  # if_1
+      }
+    }
+    %6:void = workgroupBarrier
+    %7:array<array<i32, 1>, 1> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, StructOfScalars) {
+    auto* s = ty.Struct(mod.symbols.New("MyStruct"), {
+                                                         {mod.symbols.New("a"), ty.i32()},
+                                                         {mod.symbols.New("b"), ty.u32()},
+                                                         {mod.symbols.New("c"), ty.f32()},
+                                                     });
+    auto* var = MakeVar("wgvar", s);
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+MyStruct = struct @align(4) {
+  a:i32 @offset(0)
+  b:u32 @offset(4)
+  c:f32 @offset(8)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, MyStruct, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:MyStruct = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+MyStruct = struct @align(4) {
+  a:i32 @offset(0)
+  b:u32 @offset(4)
+  c:f32 @offset(8)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, MyStruct, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, MyStruct(0i, 0u, 0.0f)
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:MyStruct = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedStructOfScalars) {
+    auto* inner = ty.Struct(mod.symbols.New("Inner"), {
+                                                          {mod.symbols.New("a"), ty.i32()},
+                                                          {mod.symbols.New("b"), ty.u32()},
+                                                      });
+    auto* outer = ty.Struct(mod.symbols.New("Outer"), {
+                                                          {mod.symbols.New("c"), ty.f32()},
+                                                          {mod.symbols.New("inner"), inner},
+                                                          {mod.symbols.New("d"), ty.bool_()},
+                                                      });
+    auto* var = MakeVar("wgvar", outer);
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+Inner = struct @align(4) {
+  a:i32 @offset(0)
+  b:u32 @offset(4)
+}
+
+Outer = struct @align(4) {
+  c:f32 @offset(0)
+  inner:Inner @offset(4)
+  d:bool @offset(12)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:Outer = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+Inner = struct @align(4) {
+  a:i32 @offset(0)
+  b:u32 @offset(4)
+}
+
+Outer = struct @align(4) {
+  c:f32 @offset(0)
+  inner:Inner @offset(4)
+  d:bool @offset(12)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, Outer(0.0f, Inner(0i, 0u), false)
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    %6:Outer = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedStructOfScalarsWithAtomic) {
+    auto* inner = ty.Struct(mod.symbols.New("Inner"), {
+                                                          {mod.symbols.New("a"), ty.i32()},
+                                                          {mod.symbols.New("b"), ty.atomic<u32>()},
+                                                      });
+    auto* outer = ty.Struct(mod.symbols.New("Outer"), {
+                                                          {mod.symbols.New("c"), ty.f32()},
+                                                          {mod.symbols.New("inner"), inner},
+                                                          {mod.symbols.New("d"), ty.bool_()},
+                                                      });
+    auto* var = MakeVar("wgvar", outer);
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+Inner = struct @align(4) {
+  a:i32 @offset(0)
+  b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+  c:f32 @offset(0)
+  inner:Inner @offset(4)
+  d:bool @offset(12)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    %3:Outer = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+Inner = struct @align(4) {
+  a:i32 @offset(0)
+  b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+  c:f32 @offset(0)
+  inner:Inner @offset(4)
+  d:bool @offset(12)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        %5:ptr<workgroup, f32, read_write> = access %wgvar, 0u
+        store %5, 0.0f
+        %6:ptr<workgroup, i32, read_write> = access %wgvar, 1u, 0u
+        store %6, 0i
+        %7:ptr<workgroup, atomic<u32>, read_write> = access %wgvar, 1u, 1u
+        %8:void = atomicStore %7, 0u
+        %9:ptr<workgroup, bool, read_write> = access %wgvar, 2u
+        store %9, false
+        exit_if  # if_1
+      }
+    }
+    %10:void = workgroupBarrier
+    %11:Outer = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfStructOfArrayOfStructWithAtomic) {
+    auto* inner = ty.Struct(mod.symbols.New("Inner"), {
+                                                          {mod.symbols.New("a"), ty.i32()},
+                                                          {mod.symbols.New("b"), ty.atomic<u32>()},
+                                                      });
+    auto* outer =
+        ty.Struct(mod.symbols.New("Outer"), {
+                                                {mod.symbols.New("c"), ty.f32()},
+                                                {mod.symbols.New("inner"), ty.array(inner, 13)},
+                                                {mod.symbols.New("d"), ty.bool_()},
+                                            });
+    auto* var = MakeVar("wgvar", ty.array(outer, 7));
+
+    auto* func = MakeEntryPoint("main", 7, 3, 2);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+Inner = struct @align(4) {
+  a:i32 @offset(0)
+  b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+  c:f32 @offset(0)
+  inner:array<Inner, 13> @offset(4)
+  d:bool @offset(108)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<Outer, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(7, 3, 2) func():void -> %b2 {
+  %b2 = block {
+    %3:array<Outer, 7> = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+Inner = struct @align(4) {
+  a:i32 @offset(0)
+  b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+  c:f32 @offset(0)
+  inner:array<Inner, 13> @offset(4)
+  d:bool @offset(108)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, array<Outer, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(7, 3, 2) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    loop [i: %b3, b: %b4, c: %b5] {  # loop_1
+      %b3 = block {  # initializer
+        next_iteration %b4 %tint_local_index
+      }
+      %b4 = block (%4:u32) {  # body
+        %5:bool = gt %4:u32, 6u
+        if %5 [t: %b6] {  # if_1
+          %b6 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %6:ptr<workgroup, f32, read_write> = access %wgvar, %4:u32, 0u
+        store %6, 0.0f
+        %7:ptr<workgroup, bool, read_write> = access %wgvar, %4:u32, 2u
+        store %7, false
+        continue %b5
+      }
+      %b5 = block {  # continuing
+        %8:u32 = add %4:u32, 42u
+        next_iteration %b4 %8
+      }
+    }
+    loop [i: %b7, b: %b8, c: %b9] {  # loop_2
+      %b7 = block {  # initializer
+        next_iteration %b8 %tint_local_index
+      }
+      %b8 = block (%9:u32) {  # body
+        %10:bool = gt %9:u32, 90u
+        if %10 [t: %b10] {  # if_2
+          %b10 = block {  # true
+            exit_loop  # loop_2
+          }
+        }
+        %11:u32 = mod %9:u32, 13u
+        %12:u32 = div %9:u32, 13u
+        %13:ptr<workgroup, i32, read_write> = access %wgvar, %12, 1u, %11, 0u
+        store %13, 0i
+        %14:u32 = mod %9:u32, 13u
+        %15:u32 = div %9:u32, 13u
+        %16:ptr<workgroup, atomic<u32>, read_write> = access %wgvar, %15, 1u, %14, 1u
+        %17:void = atomicStore %16, 0u
+        continue %b9
+      }
+      %b9 = block {  # continuing
+        %18:u32 = add %9:u32, 42u
+        next_iteration %b8 %18
+      }
+    }
+    %19:void = workgroupBarrier
+    %20:array<Outer, 7> = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, MultipleVariables_DifferentIterationCounts) {
+    auto* var_a = MakeVar("var_a", ty.bool_());
+    auto* var_b = MakeVar("var_b", ty.array<i32, 4>());
+    auto* var_c = MakeVar("var_c", ty.array(ty.array<u32, 5>(), 7));
+
+    auto* func = MakeEntryPoint("main", 11, 2, 3);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var_a);
+        b.Load(var_b);
+        b.Load(var_c);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %var_a:ptr<workgroup, bool, read_write> = var
+  %var_b:ptr<workgroup, array<i32, 4>, read_write> = var
+  %var_c:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+  %b2 = block {
+    %5:bool = load %var_a
+    %6:array<i32, 4> = load %var_b
+    %7:array<array<u32, 5>, 7> = load %var_c
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %var_a:ptr<workgroup, bool, read_write> = var
+  %var_b:ptr<workgroup, array<i32, 4>, read_write> = var
+  %var_c:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %6:bool = eq %tint_local_index, 0u
+    if %6 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %var_a, false
+        exit_if  # if_1
+      }
+    }
+    loop [i: %b4, b: %b5, c: %b6] {  # loop_1
+      %b4 = block {  # initializer
+        next_iteration %b5 %tint_local_index
+      }
+      %b5 = block (%7:u32) {  # body
+        %8:bool = gt %7:u32, 3u
+        if %8 [t: %b7] {  # if_2
+          %b7 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %9:ptr<workgroup, i32, read_write> = access %var_b, %7:u32
+        store %9, 0i
+        continue %b6
+      }
+      %b6 = block {  # continuing
+        %10:u32 = add %7:u32, 66u
+        next_iteration %b5 %10
+      }
+    }
+    loop [i: %b8, b: %b9, c: %b10] {  # loop_2
+      %b8 = block {  # initializer
+        next_iteration %b9 %tint_local_index
+      }
+      %b9 = block (%11:u32) {  # body
+        %12:bool = gt %11:u32, 34u
+        if %12 [t: %b11] {  # if_3
+          %b11 = block {  # true
+            exit_loop  # loop_2
+          }
+        }
+        %13:u32 = mod %11:u32, 5u
+        %14:u32 = div %11:u32, 5u
+        %15:ptr<workgroup, u32, read_write> = access %var_c, %14, %13
+        store %15, 0u
+        continue %b10
+      }
+      %b10 = block {  # continuing
+        %16:u32 = add %11:u32, 66u
+        next_iteration %b9 %16
+      }
+    }
+    %17:void = workgroupBarrier
+    %18:bool = load %var_a
+    %19:array<i32, 4> = load %var_b
+    %20:array<array<u32, 5>, 7> = load %var_c
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, MultipleVariables_SharedIterationCounts) {
+    auto* var_a = MakeVar("var_a", ty.bool_());
+    auto* var_b = MakeVar("var_b", ty.i32());
+    auto* var_c = MakeVar("var_c", ty.array<i32, 42>());
+    auto* var_d = MakeVar("var_d", ty.array(ty.array<u32, 6>(), 7));
+
+    auto* func = MakeEntryPoint("main", 11, 2, 3);
+    b.Append(func->Block(), [&] {  //
+        b.Load(var_a);
+        b.Load(var_b);
+        b.Load(var_c);
+        b.Load(var_d);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %var_a:ptr<workgroup, bool, read_write> = var
+  %var_b:ptr<workgroup, i32, read_write> = var
+  %var_c:ptr<workgroup, array<i32, 42>, read_write> = var
+  %var_d:ptr<workgroup, array<array<u32, 6>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+  %b2 = block {
+    %6:bool = load %var_a
+    %7:i32 = load %var_b
+    %8:array<i32, 42> = load %var_c
+    %9:array<array<u32, 6>, 7> = load %var_d
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %var_a:ptr<workgroup, bool, read_write> = var
+  %var_b:ptr<workgroup, i32, read_write> = var
+  %var_c:ptr<workgroup, array<i32, 42>, read_write> = var
+  %var_d:ptr<workgroup, array<array<u32, 6>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %7:bool = eq %tint_local_index, 0u
+    if %7 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %var_a, false
+        store %var_b, 0i
+        exit_if  # if_1
+      }
+    }
+    loop [i: %b4, b: %b5, c: %b6] {  # loop_1
+      %b4 = block {  # initializer
+        next_iteration %b5 %tint_local_index
+      }
+      %b5 = block (%8:u32) {  # body
+        %9:bool = gt %8:u32, 41u
+        if %9 [t: %b7] {  # if_2
+          %b7 = block {  # true
+            exit_loop  # loop_1
+          }
+        }
+        %10:ptr<workgroup, i32, read_write> = access %var_c, %8:u32
+        store %10, 0i
+        %11:u32 = mod %8:u32, 6u
+        %12:u32 = div %8:u32, 6u
+        %13:ptr<workgroup, u32, read_write> = access %var_d, %12, %11
+        store %13, 0u
+        continue %b6
+      }
+      %b6 = block {  # continuing
+        %14:u32 = add %8:u32, 66u
+        next_iteration %b5 %14
+      }
+    }
+    %15:void = workgroupBarrier
+    %16:bool = load %var_a
+    %17:i32 = load %var_b
+    %18:array<i32, 42> = load %var_c
+    %19:array<array<u32, 6>, 7> = load %var_d
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ExistingLocalInvocationIndex) {
+    auto* var = MakeVar("wgvar", ty.bool_());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    auto* global_id = b.FunctionParam("global_id", ty.vec4<u32>());
+    global_id->SetBuiltin(FunctionParam::Builtin::kGlobalInvocationId);
+    auto* index = b.FunctionParam("index", ty.u32());
+    index->SetBuiltin(FunctionParam::Builtin::kLocalInvocationIndex);
+    func->SetParams({global_id, index});
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%global_id:vec4<u32> [@global_invocation_id], %index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %5:bool = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%global_id:vec4<u32> [@global_invocation_id], %index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %5:bool = eq %index, 0u
+    if %5 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, false
+        exit_if  # if_1
+      }
+    }
+    %6:void = workgroupBarrier
+    %7:bool = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ExistingLocalInvocationIndexInStruct) {
+    auto* var = MakeVar("wgvar", ty.bool_());
+
+    auto* structure =
+        ty.Struct(mod.symbols.New("MyStruct"),
+                  {
+                      {
+                          mod.symbols.New("global_id"),
+                          ty.vec3<u32>(),
+                          {{}, {}, core::BuiltinValue::kGlobalInvocationId, {}, false},
+                      },
+                      {
+                          mod.symbols.New("index"),
+                          ty.u32(),
+                          {{}, {}, core::BuiltinValue::kLocalInvocationIndex, {}, false},
+                      },
+                  });
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    func->SetParams({b.FunctionParam("params", structure)});
+    b.Append(func->Block(), [&] {  //
+        b.Load(var);
+        b.Return(func);
+    });
+
+    auto* src = R"(
+MyStruct = struct @align(16) {
+  global_id:vec3<u32> @offset(0), @builtin(global_invocation_id)
+  index:u32 @offset(12), @builtin(local_invocation_index)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%params:MyStruct):void -> %b2 {
+  %b2 = block {
+    %4:bool = load %wgvar
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+MyStruct = struct @align(16) {
+  global_id:vec3<u32> @offset(0), @builtin(global_invocation_id)
+  index:u32 @offset(12), @builtin(local_invocation_index)
+}
+
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%params:MyStruct):void -> %b2 {
+  %b2 = block {
+    %4:u32 = access %params, 1u
+    %5:bool = eq %4, 0u
+    if %5 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, false
+        exit_if  # if_1
+      }
+    }
+    %6:void = workgroupBarrier
+    %7:bool = load %wgvar
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, UseInsideNestedBlock) {
+    auto* var = MakeVar("wgvar", ty.bool_());
+
+    auto* func = MakeEntryPoint("main", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        auto* ifelse = b.If(true);
+        b.Append(ifelse->True(), [&] {  //
+            auto* sw = b.Switch(42_i);
+            auto* def_case = b.Case(sw, Vector{core::ir::Switch::CaseSelector()});
+            b.Append(def_case, [&] {  //
+                auto* loop = b.Loop();
+                b.Append(loop->Body(), [&] {  //
+                    b.Continue(loop);
+                    b.Append(loop->Continuing(), [&] {  //
+                        auto* load = b.Load(var);
+                        b.BreakIf(loop, load);
+                    });
+                });
+                b.ExitSwitch(sw);
+            });
+            b.ExitIf(ifelse);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+  %b2 = block {
+    if true [t: %b3] {  # if_1
+      %b3 = block {  # true
+        switch 42i [c: (default, %b4)] {  # switch_1
+          %b4 = block {  # case
+            loop [b: %b5, c: %b6] {  # loop_1
+              %b5 = block {  # body
+                continue %b6
+              }
+              %b6 = block {  # continuing
+                %3:bool = load %wgvar
+                break_if %3 %b5
+              }
+            }
+            exit_switch  # switch_1
+          }
+        }
+        exit_if  # if_1
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+  %b2 = block {
+    %4:bool = eq %tint_local_index, 0u
+    if %4 [t: %b3] {  # if_1
+      %b3 = block {  # true
+        store %wgvar, false
+        exit_if  # if_1
+      }
+    }
+    %5:void = workgroupBarrier
+    if true [t: %b4] {  # if_2
+      %b4 = block {  # true
+        switch 42i [c: (default, %b5)] {  # switch_1
+          %b5 = block {  # case
+            loop [b: %b6, c: %b7] {  # loop_1
+              %b6 = block {  # body
+                continue %b7
+              }
+              %b7 = block {  # continuing
+                %6:bool = load %wgvar
+                break_if %6 %b6
+              }
+            }
+            exit_switch  # switch_1
+          }
+        }
+        exit_if  # if_2
+      }
+    }
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, UseInsideIndirectFunctionCall) {
+    auto* var = MakeVar("wgvar", ty.bool_());
+
+    auto* foo = b.Function("foo", ty.void_());
+    b.Append(foo->Block(), [&] {  //
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] {  //
+            b.Continue(loop);
+            b.Append(loop->Continuing(), [&] {  //
+                auto* load = b.Load(var);
+                b.BreakIf(loop, load);
+            });
+        });
+        b.Return(foo);
+    });
+
+    auto* bar = b.Function("foo", ty.void_());
+    b.Append(bar->Block(), [&] {  //
+        auto* ifelse = b.If(true);
+        b.Append(ifelse->True(), [&] {  //
+            b.Call(ty.void_(), foo);
+            b.ExitIf(ifelse);
+        });
+        b.Return(bar);
+    });
+
+    auto* func = MakeEntryPoint("func", 1, 1, 1);
+    b.Append(func->Block(), [&] {  //
+        auto* ifelse = b.If(true);
+        b.Append(ifelse->True(), [&] {  //
+            auto* sw = b.Switch(42_i);
+            auto* def_case = b.Case(sw, Vector{core::ir::Switch::CaseSelector()});
+            b.Append(def_case, [&] {  //
+                auto* loop = b.Loop();
+                b.Append(loop->Body(), [&] {  //
+                    b.Continue(loop);
+                    b.Append(loop->Continuing(), [&] {  //
+                        b.Call(ty.void_(), bar);
+                        b.BreakIf(loop, true);
+                    });
+                });
+                b.ExitSwitch(sw);
+            });
+            b.ExitIf(ifelse);
+        });
+        b.Return(func);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+  %b2 = block {
+    loop [b: %b3, c: %b4] {  # loop_1
+      %b3 = block {  # body
+        continue %b4
+      }
+      %b4 = block {  # continuing
+        %3:bool = load %wgvar
+        break_if %3 %b3
+      }
+    }
+    ret
+  }
+}
+%foo_1 = func():void -> %b5 {  # %foo_1: 'foo'
+  %b5 = block {
+    if true [t: %b6] {  # if_1
+      %b6 = block {  # true
+        %5:void = call %foo
+        exit_if  # if_1
+      }
+    }
+    ret
+  }
+}
+%func = @compute @workgroup_size(1, 1, 1) func():void -> %b7 {
+  %b7 = block {
+    if true [t: %b8] {  # if_2
+      %b8 = block {  # true
+        switch 42i [c: (default, %b9)] {  # switch_1
+          %b9 = block {  # case
+            loop [b: %b10, c: %b11] {  # loop_2
+              %b10 = block {  # body
+                continue %b11
+              }
+              %b11 = block {  # continuing
+                %7:void = call %foo_1
+                break_if true %b10
+              }
+            }
+            exit_switch  # switch_1
+          }
+        }
+        exit_if  # if_2
+      }
+    }
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+  %b2 = block {
+    loop [b: %b3, c: %b4] {  # loop_1
+      %b3 = block {  # body
+        continue %b4
+      }
+      %b4 = block {  # continuing
+        %3:bool = load %wgvar
+        break_if %3 %b3
+      }
+    }
+    ret
+  }
+}
+%foo_1 = func():void -> %b5 {  # %foo_1: 'foo'
+  %b5 = block {
+    if true [t: %b6] {  # if_1
+      %b6 = block {  # true
+        %5:void = call %foo
+        exit_if  # if_1
+      }
+    }
+    ret
+  }
+}
+%func = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b7 {
+  %b7 = block {
+    %8:bool = eq %tint_local_index, 0u
+    if %8 [t: %b8] {  # if_2
+      %b8 = block {  # true
+        store %wgvar, false
+        exit_if  # if_2
+      }
+    }
+    %9:void = workgroupBarrier
+    if true [t: %b9] {  # if_3
+      %b9 = block {  # true
+        switch 42i [c: (default, %b10)] {  # switch_1
+          %b10 = block {  # case
+            loop [b: %b11, c: %b12] {  # loop_2
+              %b11 = block {  # body
+                continue %b12
+              }
+              %b12 = block {  # continuing
+                %10:void = call %foo_1
+                break_if true %b11
+              }
+            }
+            exit_switch  # switch_1
+          }
+        }
+        exit_if  # if_3
+      }
+    }
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, MultipleEntryPoints_SameVarViaHelper) {
+    auto* var = MakeVar("wgvar", ty.bool_());
+
+    auto* foo = b.Function("foo", ty.void_());
+    b.Append(foo->Block(), [&] {  //
+        auto* loop = b.Loop();
+        b.Append(loop->Body(), [&] {  //
+            b.Continue(loop);
+            b.Append(loop->Continuing(), [&] {  //
+                auto* load = b.Load(var);
+                b.BreakIf(loop, load);
+            });
+        });
+        b.Return(foo);
+    });
+
+    auto* ep1 = MakeEntryPoint("ep1", 1, 1, 1);
+    b.Append(ep1->Block(), [&] {  //
+        b.Call(ty.void_(), foo);
+        b.Return(ep1);
+    });
+
+    auto* ep2 = MakeEntryPoint("ep2", 1, 1, 1);
+    b.Append(ep2->Block(), [&] {  //
+        b.Call(ty.void_(), foo);
+        b.Return(ep2);
+    });
+
+    auto* src = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+  %b2 = block {
+    loop [b: %b3, c: %b4] {  # loop_1
+      %b3 = block {  # body
+        continue %b4
+      }
+      %b4 = block {  # continuing
+        %3:bool = load %wgvar
+        break_if %3 %b3
+      }
+    }
+    ret
+  }
+}
+%ep1 = @compute @workgroup_size(1, 1, 1) func():void -> %b5 {
+  %b5 = block {
+    %5:void = call %foo
+    ret
+  }
+}
+%ep2 = @compute @workgroup_size(1, 1, 1) func():void -> %b6 {
+  %b6 = block {
+    %7:void = call %foo
+    ret
+  }
+}
+)";
+    EXPECT_EQ(src, str());
+
+    auto* expect = R"(
+%b1 = block {  # root
+  %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+  %b2 = block {
+    loop [b: %b3, c: %b4] {  # loop_1
+      %b3 = block {  # body
+        continue %b4
+      }
+      %b4 = block {  # continuing
+        %3:bool = load %wgvar
+        break_if %3 %b3
+      }
+    }
+    ret
+  }
+}
+%ep1 = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b5 {
+  %b5 = block {
+    %6:bool = eq %tint_local_index, 0u
+    if %6 [t: %b6] {  # if_1
+      %b6 = block {  # true
+        store %wgvar, false
+        exit_if  # if_1
+      }
+    }
+    %7:void = workgroupBarrier
+    %8:void = call %foo
+    ret
+  }
+}
+%ep2 = @compute @workgroup_size(1, 1, 1) func(%tint_local_index_1:u32 [@local_invocation_index]):void -> %b7 {  # %tint_local_index_1: 'tint_local_index'
+  %b7 = block {
+    %11:bool = eq %tint_local_index_1, 0u
+    if %11 [t: %b8] {  # if_2
+      %b8 = block {  # true
+        store %wgvar, false
+        exit_if  # if_2
+      }
+    }
+    %12:void = workgroupBarrier
+    %13:void = call %foo
+    ret
+  }
+}
+)";
+
+    Run(ZeroInitWorkgroupMemory);
+
+    EXPECT_EQ(expect, str());
+}
+
+}  // namespace
+}  // namespace tint::core::ir::transform
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 6a5c39f..93989a4 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -25,6 +25,7 @@
 #include "src/tint/lang/core/ir/transform/multiplanar_external_texture.h"
 #include "src/tint/lang/core/ir/transform/robustness.h"
 #include "src/tint/lang/core/ir/transform/std140.h"
+#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
 #include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
 #include "src/tint/lang/spirv/writer/raise/expand_implicit_splats.h"
 #include "src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h"
@@ -70,6 +71,11 @@
     RUN_TRANSFORM(core::ir::transform::MultiplanarExternalTexture, module,
                   options.external_texture_options);
 
+    if (!options.disable_workgroup_init &&
+        !options.use_zero_initialize_workgroup_memory_extension) {
+        RUN_TRANSFORM(core::ir::transform::ZeroInitWorkgroupMemory, module);
+    }
+
     RUN_TRANSFORM(core::ir::transform::AddEmptyEntryPoint, module);
     RUN_TRANSFORM(core::ir::transform::Bgra8UnormPolyfill, module);
     RUN_TRANSFORM(core::ir::transform::BlockDecoratedStructs, module);