[ir] Add ZeroInitWorkgroupMemory transform
Use it in the SPIR-V writer.
Bug: tint:1718, tint:1906
Change-Id: Ie1fe7491518aa0d793f45eda52f6a82bcd5485e3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/152463
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/transform/BUILD.bazel b/src/tint/lang/core/ir/transform/BUILD.bazel
index 7216985..4388198 100644
--- a/src/tint/lang/core/ir/transform/BUILD.bazel
+++ b/src/tint/lang/core/ir/transform/BUILD.bazel
@@ -37,6 +37,7 @@
"robustness.cc",
"shader_io.cc",
"std140.cc",
+ "zero_init_workgroup_memory.cc",
],
hdrs = [
"add_empty_entry_point.h",
@@ -50,6 +51,7 @@
"robustness.h",
"shader_io.h",
"std140.h",
+ "zero_init_workgroup_memory.h",
],
deps = [
"//src/tint/api/common",
@@ -92,6 +94,7 @@
"multiplanar_external_texture_test.cc",
"robustness_test.cc",
"std140_test.cc",
+ "zero_init_workgroup_memory_test.cc",
],
deps = [
"//src/tint/api/common",
diff --git a/src/tint/lang/core/ir/transform/BUILD.cmake b/src/tint/lang/core/ir/transform/BUILD.cmake
index f474b48..5e772df 100644
--- a/src/tint/lang/core/ir/transform/BUILD.cmake
+++ b/src/tint/lang/core/ir/transform/BUILD.cmake
@@ -48,6 +48,8 @@
lang/core/ir/transform/shader_io.h
lang/core/ir/transform/std140.cc
lang/core/ir/transform/std140.h
+ lang/core/ir/transform/zero_init_workgroup_memory.cc
+ lang/core/ir/transform/zero_init_workgroup_memory.h
)
tint_target_add_dependencies(tint_lang_core_ir_transform lib
@@ -90,6 +92,7 @@
lang/core/ir/transform/multiplanar_external_texture_test.cc
lang/core/ir/transform/robustness_test.cc
lang/core/ir/transform/std140_test.cc
+ lang/core/ir/transform/zero_init_workgroup_memory_test.cc
)
tint_target_add_dependencies(tint_lang_core_ir_transform_test test
diff --git a/src/tint/lang/core/ir/transform/BUILD.gn b/src/tint/lang/core/ir/transform/BUILD.gn
index 82e462b..cf99aba 100644
--- a/src/tint/lang/core/ir/transform/BUILD.gn
+++ b/src/tint/lang/core/ir/transform/BUILD.gn
@@ -53,6 +53,8 @@
"shader_io.h",
"std140.cc",
"std140.h",
+ "zero_init_workgroup_memory.cc",
+ "zero_init_workgroup_memory.h",
]
deps = [
"${tint_src_dir}/api/common",
@@ -93,6 +95,7 @@
"multiplanar_external_texture_test.cc",
"robustness_test.cc",
"std140_test.cc",
+ "zero_init_workgroup_memory_test.cc",
]
deps = [
"${tint_src_dir}:gmock_and_gtest",
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
new file mode 100644
index 0000000..cd6c589
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.cc
@@ -0,0 +1,403 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
+
+#include <map>
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/utils/containers/reverse.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+namespace tint::core::ir::transform {
+
+namespace {
+
+/// PIMPL state for the transform.
+struct State {
+ /// The IR module.
+ Module& ir;
+
+ /// The IR builder.
+ Builder b{ir};
+
+ /// The type manager.
+ core::type::Manager& ty{ir.Types()};
+
+ /// VarSet is a hash set of workgroup variables.
+ using VarSet = Hashset<Var*, 8>;
+
+ /// A map from variable to an ID used for sorting.
+ Hashmap<Var*, uint32_t, 8> var_to_id{};
+
+ /// A map from blocks to their directly referenced workgroup variables.
+ Hashmap<Block*, VarSet, 64> block_to_direct_vars{};
+
+ /// A map from functions to their transitively referenced workgroup variables.
+ Hashmap<Function*, VarSet, 8> function_to_transitive_vars{};
+
+ /// ArrayIndex represents a required array index for an access instruction.
+ struct ArrayIndex {
+ /// The size of the array that will be indexed.
+ uint32_t count = 0u;
+ };
+
+ /// Index represents an index for an access instruction, which is either a constant value or
+ /// an array index that will be dynamically calculated from an array size.
+ using Index = std::variant<uint32_t, ArrayIndex>;
+
+ /// Store describes a store to a sub-element of a workgroup variable.
+ struct Store {
+ /// The workgroup variable.
+ Var* var = nullptr;
+ /// The store type of the element.
+ const type::Type* store_type = nullptr;
+ /// The list of index operands to get to the element.
+ Vector<Index, 4> indices;
+ };
+
+ /// StoreList is a list of `Store` descriptors.
+ using StoreList = Vector<Store, 8>;
+
+ /// StoreMap is a map from iteration count to a list of `Store` descriptors.
+ using StoreMap = Hashmap<uint32_t, StoreList, 8>;
+
+ /// Process the module.
+ void Process() {
+ if (!ir.root_block) {
+ return;
+ }
+
+ // Loop over module-scope variables, looking for workgroup variables.
+ uint32_t next_id = 0;
+ for (auto inst : *ir.root_block) {
+ if (auto* var = inst->As<Var>()) {
+ auto* ptr = var->Result()->Type()->As<core::type::Pointer>();
+ if (ptr && ptr->AddressSpace() == core::AddressSpace::kWorkgroup) {
+ // Record the usage of the variable for each block that references it.
+ var->Result()->ForEachUse([&](const Usage& use) {
+ block_to_direct_vars.GetOrZero(use.instruction->Block())->Add(var);
+ });
+ var_to_id.Add(var, next_id++);
+ }
+ }
+ }
+
+ // Process each entry point function.
+ for (auto* func : ir.functions) {
+ if (func->Stage() == Function::PipelineStage::kCompute) {
+ ProcessEntryPoint(func);
+ }
+ }
+ }
+
+ /// Process an entry point function to zero-initialize the workgroup variables that it uses.
+ /// @param func the entry point function
+ void ProcessEntryPoint(Function* func) {
+ // Get list of transitively referenced workgroup variables.
+ auto vars = GetReferencedVars(func);
+ if (vars.IsEmpty()) {
+ return;
+ }
+
+ // Sort the variables to get deterministic output in tests.
+ auto sorted_vars = vars.Vector();
+ sorted_vars.Sort([&](Var* first, Var* second) {
+ return *var_to_id.Get(first) < *var_to_id.Get(second);
+ });
+
+ // Build list of store descriptors for all workgroup variables.
+ StoreMap stores;
+ for (auto* var : sorted_vars) {
+ PrepareStores(var, var->Result()->Type()->UnwrapPtr(), 1, {}, stores);
+ }
+
+ // Sort the iteration counts to get deterministic output in tests.
+ auto sorted_iteration_counts = stores.Keys();
+ sorted_iteration_counts.Sort();
+
+ // Capture the first instruction of the function.
+ // All new instructions will be inserted before this.
+ auto* function_start = func->Block()->Front();
+
+ // Get the local invocation index and the linearized workgroup size.
+ auto* local_index = GetLocalInvocationIndex(func);
+ auto wgsizes = func->WorkgroupSize().value();
+ auto wgsize = wgsizes[0] * wgsizes[1] * wgsizes[2];
+
+ // Insert instructions to zero-initialize every variable.
+ b.InsertBefore(function_start, [&] {
+ for (auto count : sorted_iteration_counts) {
+ auto element_stores = stores.Get(count);
+ if (count == 1u) {
+ // Make the first invocation in the group perform all of the non-arrayed stores.
+ auto* ifelse = b.If(b.Equal(ty.bool_(), local_index, 0_u));
+ b.Append(ifelse->True(), [&] {
+ for (auto& store : *element_stores) {
+ GenerateStore(store, count, b.Constant(0_u));
+ }
+ b.ExitIf(ifelse);
+ });
+ } else {
+ // Use a loop for arrayed stores.
+ GenerateZeroingLoop(local_index, count, wgsize, *element_stores);
+ }
+ }
+ b.Call(ty.void_(), core::BuiltinFn::kWorkgroupBarrier);
+ });
+ }
+
+ /// Get the set of workgroup variables transitively referenced by @p func.
+ /// @param func the function
+ /// @returns the set of transitively referenced workgroup variables
+ VarSet GetReferencedVars(Function* func) {
+ return function_to_transitive_vars.GetOrCreate(func, [&] {
+ VarSet vars;
+ GetReferencedVars(func->Block(), vars);
+ return vars;
+ });
+ }
+
+ /// Get the set of workgroup variables transitively referenced by @p block.
+ /// @param block the block
+ /// @param vars the set of transitively referenced workgroup variables to populate
+ void GetReferencedVars(Block* block, VarSet& vars) {
+ // Add directly referenced vars.
+ if (auto itr = block_to_direct_vars.Find(block)) {
+ for (auto* var : *itr) {
+ vars.Add(var);
+ }
+ }
+
+ // Loop over instructions in the block.
+ for (auto* inst : *block) {
+ tint::Switch(
+ inst,
+ [&](UserCall* call) {
+ // Get variables referenced by a function called from this block.
+ auto callee_vars = GetReferencedVars(call->Target());
+ for (auto* var : callee_vars) {
+ vars.Add(var);
+ }
+ },
+ [&](ControlInstruction* ctrl) {
+ // Recurse into control instructions and gather their referenced vars.
+ ctrl->ForeachBlock([&](Block* blk) { GetReferencedVars(blk, vars); });
+ });
+ }
+ }
+
+ /// Recursively generate store descriptors for a workgroup variable.
+ /// Determines the combined array iteration count of each inner element.
+ /// @param var the workgroup variable
+ /// @param type the current element type
+ /// @param iteration_count the iteration count of this inner element of the variable
+ /// @param indices the access indices needed to get to this element
+ /// @param stores the map of stores to populate
+ void PrepareStores(Var* var,
+ const type::Type* type,
+ uint32_t iteration_count,
+ Vector<Index, 4> indices,
+ StoreMap& stores) {
+ // If this type can be trivially zeroed, store to the whole element.
+ if (CanTriviallyZero(type)) {
+ stores.GetOrZero(iteration_count)->Push(Store{var, type, indices});
+ return;
+ }
+
+ tint::Switch(
+ type,
+ [&](const type::Array* arr) {
+ // Add an array index to the list and recurse into the element type.
+ TINT_ASSERT(arr->ConstantCount());
+ auto count = arr->ConstantCount().value();
+ auto new_indices = indices;
+ if (count > 1) {
+ new_indices.Push(ArrayIndex{count});
+ } else {
+ new_indices.Push(0u);
+ }
+ PrepareStores(var, arr->ElemType(), iteration_count * count, new_indices, stores);
+ },
+ [&](const type::Atomic*) {
+ stores.GetOrZero(iteration_count)->Push(Store{var, type, indices});
+ },
+ [&](const type::Struct* str) {
+ for (auto* member : str->Members()) {
+ // Add the member index to the index list and recurse into its type.
+ auto new_indices = indices;
+ new_indices.Push(member->Index());
+ PrepareStores(var, member->Type(), iteration_count, new_indices, stores);
+ }
+ },
+ [&](Default) { TINT_UNREACHABLE(); });
+ }
+
+ /// Get or inject an entry point builtin for the local invocation index.
+ /// @param func the entry point function
+ /// @returns the local invocation index builtin
+ Value* GetLocalInvocationIndex(Function* func) {
+ // Look for an existing local_invocation_index builtin parameter.
+ for (auto* param : func->Params()) {
+ if (auto* str = param->Type()->As<type::Struct>()) {
+ // Check each member for the local invocation index builtin attribute.
+ for (auto* member : str->Members()) {
+ if (member->Attributes().builtin && member->Attributes().builtin.value() ==
+ BuiltinValue::kLocalInvocationIndex) {
+ auto* access = b.Access(ty.u32(), param, u32(member->Index()));
+ access->InsertBefore(func->Block()->Front());
+ return access->Result();
+ }
+ }
+ } else {
+ // Check if the parameter is the local invocation index.
+ if (param->Builtin() &&
+ param->Builtin().value() == FunctionParam::Builtin::kLocalInvocationIndex) {
+ return param;
+ }
+ }
+ }
+
+ // No local invocation index was found, so add one to the parameter list and use that.
+ Vector<FunctionParam*, 4> params = func->Params();
+ auto* param = b.FunctionParam("tint_local_index", ty.u32());
+ param->SetBuiltin(FunctionParam::Builtin::kLocalInvocationIndex);
+ params.Push(param);
+ func->SetParams(params);
+ return param;
+ }
+
+ /// Generate the store instruction for a given store descriptor.
+ /// @param store the store descriptor
+ /// @param total_count the total number of elements that will be zeroed
+ /// @param linear_index the linear index of the single element that will be zeroed
+ void GenerateStore(const Store& store, uint32_t total_count, Value* linear_index) {
+ auto* to = store.var->Result();
+ if (!store.indices.IsEmpty()) {
+ // Build the access indices to get to the target element.
+ // We walk backwards along the index list so that adjacent invocation store to
+ // adjacent array elements.
+ uint32_t count = 1;
+ Vector<Value*, 4> indices;
+ for (auto idx : Reverse(store.indices)) {
+ if (std::holds_alternative<ArrayIndex>(idx)) {
+ // Array indices are computed from the linear index based on the size of the
+ // array and the size of the sub-arrays that have already been indexed.
+ auto array_index = std::get<ArrayIndex>(idx);
+ Value* index = linear_index;
+ if (count > 1) {
+ index = b.Divide(ty.u32(), index, u32(count))->Result();
+ }
+ if (total_count > count * array_index.count) {
+ index = b.Modulo(ty.u32(), index, u32(array_index.count))->Result();
+ }
+ indices.Push(index);
+ count *= array_index.count;
+ } else {
+ // Constant indices are added to the list unmodified.
+ indices.Push(b.Constant(u32(std::get<uint32_t>(idx))));
+ }
+ }
+ indices.Reverse();
+ to = b.Access(ty.ptr(workgroup, store.store_type), to, indices)->Result();
+ }
+
+ // Generate the store instruction.
+ if (auto* atomic = store.store_type->As<type::Atomic>()) {
+ auto* zero = b.Constant(ir.constant_values.Zero(atomic->Type()));
+ b.Call(ty.void_(), core::BuiltinFn::kAtomicStore, to, zero);
+ } else {
+ auto* zero = b.Constant(ir.constant_values.Zero(store.store_type));
+ b.Store(to, zero);
+ }
+ }
+
+ /// Generate a loop for a list of stores with the same iteration count.
+ /// @param local_index the local invocation index
+ /// @param total_count the number of iterations needed to store to all elements
+ /// @param wgsize the linearized workgroup size
+ /// @param stores the list of store descriptors
+ void GenerateZeroingLoop(Value* local_index,
+ uint32_t total_count,
+ uint32_t wgsize,
+ const StoreList& stores) {
+ // The loop is equivalent to:
+ // for (var idx = local_index; idx < linear_iteration_count; idx += wgsize) {
+ // <store to elements at `idx`>
+ // }
+ auto* loop = b.Loop();
+ auto* index = b.BlockParam(ty.u32());
+ loop->Body()->SetParams({index});
+ b.Append(loop->Initializer(), [&] { //
+ b.NextIteration(loop, local_index);
+ });
+ b.Append(loop->Body(), [&] {
+ // Exit the loop when the iteration count has been exceeded.
+ auto* gt_max = b.GreaterThan(ty.bool_(), index, u32(total_count - 1u));
+ auto* ifelse = b.If(gt_max);
+ b.Append(ifelse->True(), [&] { //
+ b.ExitLoop(loop);
+ });
+
+ // Insert all of the store instructions.
+ for (auto& store : stores) {
+ GenerateStore(store, total_count, index);
+ }
+
+ b.Continue(loop);
+ });
+ b.Append(loop->Continuing(), [&] { //
+ // Increment the loop index by linearized workgroup size.
+ b.NextIteration(loop, b.Add(ty.u32(), index, u32(wgsize)));
+ });
+ }
+
+ /// Check if a type can be efficiently zeroed with a single store. Returns `false` if there are
+ /// any nested arrays or atomics.
+ /// @param type the type to inspect
+ /// @returns true if a variable with store type @p ty can be efficiently zeroed
+ bool CanTriviallyZero(const core::type::Type* type) {
+ if (type->IsAnyOf<core::type::Atomic, core::type::Array>()) {
+ return false;
+ }
+ if (auto* str = type->As<core::type::Struct>()) {
+ for (auto* member : str->Members()) {
+ if (!CanTriviallyZero(member->Type())) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+};
+
+} // namespace
+
+Result<SuccessType> ZeroInitWorkgroupMemory(Module& ir) {
+ auto result = ValidateAndDumpIfNeeded(ir, "ZeroInitWorkgroupMemory transform");
+ if (!result) {
+ return result;
+ }
+
+ State{ir}.Process();
+
+ return Success;
+}
+
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h
new file mode 100644
index 0000000..02306c5
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h
@@ -0,0 +1,37 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_TINT_LANG_CORE_IR_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+#define SRC_TINT_LANG_CORE_IR_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
+
+#include <string>
+
+#include "src/tint/utils/result/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::core::ir::transform {
+
+/// ZeroInitWorkgroupMemory is a transform that injects code at the top of each entry point to
+/// zero-initialize workgroup memory used by that entry point.
+/// @param module the module to transform
+/// @returns success or failure
+Result<SuccessType> ZeroInitWorkgroupMemory(Module& module);
+
+} // namespace tint::core::ir::transform
+
+#endif // SRC_TINT_LANG_CORE_IR_TRANSFORM_ZERO_INIT_WORKGROUP_MEMORY_H_
diff --git a/src/tint/lang/core/ir/transform/zero_init_workgroup_memory_test.cc b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory_test.cc
new file mode 100644
index 0000000..d2c556b
--- /dev/null
+++ b/src/tint/lang/core/ir/transform/zero_init_workgroup_memory_test.cc
@@ -0,0 +1,1940 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+
+namespace tint::core::ir::transform {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+class IR_ZeroInitWorkgroupMemoryTest : public TransformTest {
+ protected:
+ Function* MakeEntryPoint(const char* name,
+ uint32_t wgsize_x,
+ uint32_t wgsize_y,
+ uint32_t wgsize_z) {
+ auto* func = b.Function(name, ty.void_(), Function::PipelineStage::kCompute);
+ func->SetWorkgroupSize(wgsize_x, wgsize_y, wgsize_z);
+ return func;
+ }
+
+ Var* MakeVar(const char* name, const type::Type* store_type) {
+ auto* var = b.Var(name, ty.ptr(workgroup, store_type));
+ b.RootBlock()->Append(var);
+ return var;
+ }
+};
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NoRootBlock) {
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Return(func);
+ });
+
+ auto* expect = R"(
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b1 {
+ %b1 = block {
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, WorkgroupVarUnused) {
+ MakeVar("wgvar", ty.i32());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, i32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = src;
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarBool) {
+ auto* var = MakeVar("wgvar", ty.bool_());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:bool = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, false
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:bool = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarI32) {
+ auto* var = MakeVar("wgvar", ty.i32());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, i32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:i32 = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, i32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, 0i
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:i32 = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarU32) {
+ auto* var = MakeVar("wgvar", ty.u32());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, u32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:u32 = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, u32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, 0u
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:u32 = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarF32) {
+ auto* var = MakeVar("wgvar", ty.f32());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, f32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:f32 = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, f32, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, 0.0f
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:f32 = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ScalarF16) {
+ auto* var = MakeVar("wgvar", ty.f16());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, f16, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:f16 = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, f16, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, 0.0h
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:f16 = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, AtomicI32) {
+ auto* var = MakeVar("wgvar", ty.atomic<i32>());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Call(ty.i32(), core::BuiltinFn::kAtomicLoad, var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, atomic<i32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:i32 = atomicLoad %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, atomic<i32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ %5:void = atomicStore %wgvar, 0i
+ exit_if # if_1
+ }
+ }
+ %6:void = workgroupBarrier
+ %7:i32 = atomicLoad %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, AtomicU32) {
+ auto* var = MakeVar("wgvar", ty.atomic<u32>());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Call(ty.u32(), core::BuiltinFn::kAtomicLoad, var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, atomic<u32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:u32 = atomicLoad %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, atomic<u32>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ %5:void = atomicStore %wgvar, 0u
+ exit_if # if_1
+ }
+ }
+ %6:void = workgroupBarrier
+ %7:u32 = atomicLoad %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfI32) {
+ auto* var = MakeVar("wgvar", ty.array<i32, 4>());
+
+ auto* func = MakeEntryPoint("main", 11, 2, 3);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<i32, 4>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+ %b2 = block {
+ %3:array<i32, 4> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<i32, 4>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 3u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:ptr<workgroup, i32, read_write> = access %wgvar, %4:u32
+ store %6, 0i
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %7:u32 = add %4:u32, 66u
+ next_iteration %b4 %7
+ }
+ }
+ %8:void = workgroupBarrier
+ %9:array<i32, 4> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfArrayOfU32) {
+ auto* var = MakeVar("wgvar", ty.array(ty.array<u32, 5>(), 7));
+
+ auto* func = MakeEntryPoint("main", 11, 2, 3);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+ %b2 = block {
+ %3:array<array<u32, 5>, 7> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 34u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:u32 = mod %4:u32, 5u
+ %7:u32 = div %4:u32, 5u
+ %8:ptr<workgroup, u32, read_write> = access %wgvar, %7, %6
+ store %8, 0u
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %9:u32 = add %4:u32, 66u
+ next_iteration %b4 %9
+ }
+ }
+ %10:void = workgroupBarrier
+ %11:array<array<u32, 5>, 7> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfArrayOfArray) {
+ auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 7>(), 5), 3));
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 7>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:array<array<array<i32, 7>, 5>, 3> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 7>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 104u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:u32 = mod %4:u32, 7u
+ %7:u32 = div %4:u32, 7u
+ %8:u32 = mod %7, 5u
+ %9:u32 = div %4:u32, 35u
+ %10:ptr<workgroup, i32, read_write> = access %wgvar, %9, %8, %6
+ store %10, 0i
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %11:u32 = add %4:u32, 1u
+ next_iteration %b4 %11
+ }
+ }
+ %12:void = workgroupBarrier
+ %13:array<array<array<i32, 7>, 5>, 3> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayInnerSizeOne) {
+ auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 1>(), 5), 3));
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 1>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:array<array<array<i32, 1>, 5>, 3> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 1>, 5>, 3>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 14u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:u32 = mod %4:u32, 5u
+ %7:u32 = div %4:u32, 5u
+ %8:ptr<workgroup, i32, read_write> = access %wgvar, %7, %6, 0u
+ store %8, 0i
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %9:u32 = add %4:u32, 1u
+ next_iteration %b4 %9
+ }
+ }
+ %10:void = workgroupBarrier
+ %11:array<array<array<i32, 1>, 5>, 3> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayMiddleSizeOne) {
+ auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 3>(), 1), 5));
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 3>, 1>, 5>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:array<array<array<i32, 3>, 1>, 5> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 3>, 1>, 5>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 14u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:u32 = mod %4:u32, 3u
+ %7:u32 = div %4:u32, 3u
+ %8:ptr<workgroup, i32, read_write> = access %wgvar, %7, 0u, %6
+ store %8, 0i
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %9:u32 = add %4:u32, 1u
+ next_iteration %b4 %9
+ }
+ }
+ %10:void = workgroupBarrier
+ %11:array<array<array<i32, 3>, 1>, 5> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayOuterSizeOne) {
+ auto* var = MakeVar("wgvar", ty.array(ty.array(ty.array<i32, 3>(), 5), 1));
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 3>, 5>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:array<array<array<i32, 3>, 5>, 1> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<array<i32, 3>, 5>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 14u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:u32 = mod %4:u32, 3u
+ %7:u32 = div %4:u32, 3u
+ %8:ptr<workgroup, i32, read_write> = access %wgvar, 0u, %7, %6
+ store %8, 0i
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %9:u32 = add %4:u32, 1u
+ next_iteration %b4 %9
+ }
+ }
+ %10:void = workgroupBarrier
+ %11:array<array<array<i32, 3>, 5>, 1> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedArrayTotalSizeOne) {
+ auto* var = MakeVar("wgvar", ty.array(ty.array<i32, 1>(), 1));
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<i32, 1>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:array<array<i32, 1>, 1> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<array<i32, 1>, 1>, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ %5:ptr<workgroup, i32, read_write> = access %wgvar, 0u, 0u
+ store %5, 0i
+ exit_if # if_1
+ }
+ }
+ %6:void = workgroupBarrier
+ %7:array<array<i32, 1>, 1> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, StructOfScalars) {
+ auto* s = ty.Struct(mod.symbols.New("MyStruct"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.u32()},
+ {mod.symbols.New("c"), ty.f32()},
+ });
+ auto* var = MakeVar("wgvar", s);
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+MyStruct = struct @align(4) {
+ a:i32 @offset(0)
+ b:u32 @offset(4)
+ c:f32 @offset(8)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, MyStruct, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:MyStruct = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+MyStruct = struct @align(4) {
+ a:i32 @offset(0)
+ b:u32 @offset(4)
+ c:f32 @offset(8)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, MyStruct, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, MyStruct(0i, 0u, 0.0f)
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:MyStruct = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedStructOfScalars) {
+ auto* inner = ty.Struct(mod.symbols.New("Inner"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.u32()},
+ });
+ auto* outer = ty.Struct(mod.symbols.New("Outer"), {
+ {mod.symbols.New("c"), ty.f32()},
+ {mod.symbols.New("inner"), inner},
+ {mod.symbols.New("d"), ty.bool_()},
+ });
+ auto* var = MakeVar("wgvar", outer);
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+Inner = struct @align(4) {
+ a:i32 @offset(0)
+ b:u32 @offset(4)
+}
+
+Outer = struct @align(4) {
+ c:f32 @offset(0)
+ inner:Inner @offset(4)
+ d:bool @offset(12)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:Outer = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inner = struct @align(4) {
+ a:i32 @offset(0)
+ b:u32 @offset(4)
+}
+
+Outer = struct @align(4) {
+ c:f32 @offset(0)
+ inner:Inner @offset(4)
+ d:bool @offset(12)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, Outer(0.0f, Inner(0i, 0u), false)
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ %6:Outer = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, NestedStructOfScalarsWithAtomic) {
+ auto* inner = ty.Struct(mod.symbols.New("Inner"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.atomic<u32>()},
+ });
+ auto* outer = ty.Struct(mod.symbols.New("Outer"), {
+ {mod.symbols.New("c"), ty.f32()},
+ {mod.symbols.New("inner"), inner},
+ {mod.symbols.New("d"), ty.bool_()},
+ });
+ auto* var = MakeVar("wgvar", outer);
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+Inner = struct @align(4) {
+ a:i32 @offset(0)
+ b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+ c:f32 @offset(0)
+ inner:Inner @offset(4)
+ d:bool @offset(12)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ %3:Outer = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inner = struct @align(4) {
+ a:i32 @offset(0)
+ b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+ c:f32 @offset(0)
+ inner:Inner @offset(4)
+ d:bool @offset(12)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, Outer, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ %5:ptr<workgroup, f32, read_write> = access %wgvar, 0u
+ store %5, 0.0f
+ %6:ptr<workgroup, i32, read_write> = access %wgvar, 1u, 0u
+ store %6, 0i
+ %7:ptr<workgroup, atomic<u32>, read_write> = access %wgvar, 1u, 1u
+ %8:void = atomicStore %7, 0u
+ %9:ptr<workgroup, bool, read_write> = access %wgvar, 2u
+ store %9, false
+ exit_if # if_1
+ }
+ }
+ %10:void = workgroupBarrier
+ %11:Outer = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ArrayOfStructOfArrayOfStructWithAtomic) {
+ auto* inner = ty.Struct(mod.symbols.New("Inner"), {
+ {mod.symbols.New("a"), ty.i32()},
+ {mod.symbols.New("b"), ty.atomic<u32>()},
+ });
+ auto* outer =
+ ty.Struct(mod.symbols.New("Outer"), {
+ {mod.symbols.New("c"), ty.f32()},
+ {mod.symbols.New("inner"), ty.array(inner, 13)},
+ {mod.symbols.New("d"), ty.bool_()},
+ });
+ auto* var = MakeVar("wgvar", ty.array(outer, 7));
+
+ auto* func = MakeEntryPoint("main", 7, 3, 2);
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+Inner = struct @align(4) {
+ a:i32 @offset(0)
+ b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+ c:f32 @offset(0)
+ inner:array<Inner, 13> @offset(4)
+ d:bool @offset(108)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<Outer, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(7, 3, 2) func():void -> %b2 {
+ %b2 = block {
+ %3:array<Outer, 7> = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+Inner = struct @align(4) {
+ a:i32 @offset(0)
+ b:atomic<u32> @offset(4)
+}
+
+Outer = struct @align(4) {
+ c:f32 @offset(0)
+ inner:array<Inner, 13> @offset(4)
+ d:bool @offset(108)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, array<Outer, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(7, 3, 2) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ loop [i: %b3, b: %b4, c: %b5] { # loop_1
+ %b3 = block { # initializer
+ next_iteration %b4 %tint_local_index
+ }
+ %b4 = block (%4:u32) { # body
+ %5:bool = gt %4:u32, 6u
+ if %5 [t: %b6] { # if_1
+ %b6 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %6:ptr<workgroup, f32, read_write> = access %wgvar, %4:u32, 0u
+ store %6, 0.0f
+ %7:ptr<workgroup, bool, read_write> = access %wgvar, %4:u32, 2u
+ store %7, false
+ continue %b5
+ }
+ %b5 = block { # continuing
+ %8:u32 = add %4:u32, 42u
+ next_iteration %b4 %8
+ }
+ }
+ loop [i: %b7, b: %b8, c: %b9] { # loop_2
+ %b7 = block { # initializer
+ next_iteration %b8 %tint_local_index
+ }
+ %b8 = block (%9:u32) { # body
+ %10:bool = gt %9:u32, 90u
+ if %10 [t: %b10] { # if_2
+ %b10 = block { # true
+ exit_loop # loop_2
+ }
+ }
+ %11:u32 = mod %9:u32, 13u
+ %12:u32 = div %9:u32, 13u
+ %13:ptr<workgroup, i32, read_write> = access %wgvar, %12, 1u, %11, 0u
+ store %13, 0i
+ %14:u32 = mod %9:u32, 13u
+ %15:u32 = div %9:u32, 13u
+ %16:ptr<workgroup, atomic<u32>, read_write> = access %wgvar, %15, 1u, %14, 1u
+ %17:void = atomicStore %16, 0u
+ continue %b9
+ }
+ %b9 = block { # continuing
+ %18:u32 = add %9:u32, 42u
+ next_iteration %b8 %18
+ }
+ }
+ %19:void = workgroupBarrier
+ %20:array<Outer, 7> = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, MultipleVariables_DifferentIterationCounts) {
+ auto* var_a = MakeVar("var_a", ty.bool_());
+ auto* var_b = MakeVar("var_b", ty.array<i32, 4>());
+ auto* var_c = MakeVar("var_c", ty.array(ty.array<u32, 5>(), 7));
+
+ auto* func = MakeEntryPoint("main", 11, 2, 3);
+ b.Append(func->Block(), [&] { //
+ b.Load(var_a);
+ b.Load(var_b);
+ b.Load(var_c);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %var_a:ptr<workgroup, bool, read_write> = var
+ %var_b:ptr<workgroup, array<i32, 4>, read_write> = var
+ %var_c:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+ %b2 = block {
+ %5:bool = load %var_a
+ %6:array<i32, 4> = load %var_b
+ %7:array<array<u32, 5>, 7> = load %var_c
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %var_a:ptr<workgroup, bool, read_write> = var
+ %var_b:ptr<workgroup, array<i32, 4>, read_write> = var
+ %var_c:ptr<workgroup, array<array<u32, 5>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %6:bool = eq %tint_local_index, 0u
+ if %6 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %var_a, false
+ exit_if # if_1
+ }
+ }
+ loop [i: %b4, b: %b5, c: %b6] { # loop_1
+ %b4 = block { # initializer
+ next_iteration %b5 %tint_local_index
+ }
+ %b5 = block (%7:u32) { # body
+ %8:bool = gt %7:u32, 3u
+ if %8 [t: %b7] { # if_2
+ %b7 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %9:ptr<workgroup, i32, read_write> = access %var_b, %7:u32
+ store %9, 0i
+ continue %b6
+ }
+ %b6 = block { # continuing
+ %10:u32 = add %7:u32, 66u
+ next_iteration %b5 %10
+ }
+ }
+ loop [i: %b8, b: %b9, c: %b10] { # loop_2
+ %b8 = block { # initializer
+ next_iteration %b9 %tint_local_index
+ }
+ %b9 = block (%11:u32) { # body
+ %12:bool = gt %11:u32, 34u
+ if %12 [t: %b11] { # if_3
+ %b11 = block { # true
+ exit_loop # loop_2
+ }
+ }
+ %13:u32 = mod %11:u32, 5u
+ %14:u32 = div %11:u32, 5u
+ %15:ptr<workgroup, u32, read_write> = access %var_c, %14, %13
+ store %15, 0u
+ continue %b10
+ }
+ %b10 = block { # continuing
+ %16:u32 = add %11:u32, 66u
+ next_iteration %b9 %16
+ }
+ }
+ %17:void = workgroupBarrier
+ %18:bool = load %var_a
+ %19:array<i32, 4> = load %var_b
+ %20:array<array<u32, 5>, 7> = load %var_c
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, MultipleVariables_SharedIterationCounts) {
+ auto* var_a = MakeVar("var_a", ty.bool_());
+ auto* var_b = MakeVar("var_b", ty.i32());
+ auto* var_c = MakeVar("var_c", ty.array<i32, 42>());
+ auto* var_d = MakeVar("var_d", ty.array(ty.array<u32, 6>(), 7));
+
+ auto* func = MakeEntryPoint("main", 11, 2, 3);
+ b.Append(func->Block(), [&] { //
+ b.Load(var_a);
+ b.Load(var_b);
+ b.Load(var_c);
+ b.Load(var_d);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %var_a:ptr<workgroup, bool, read_write> = var
+ %var_b:ptr<workgroup, i32, read_write> = var
+ %var_c:ptr<workgroup, array<i32, 42>, read_write> = var
+ %var_d:ptr<workgroup, array<array<u32, 6>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func():void -> %b2 {
+ %b2 = block {
+ %6:bool = load %var_a
+ %7:i32 = load %var_b
+ %8:array<i32, 42> = load %var_c
+ %9:array<array<u32, 6>, 7> = load %var_d
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %var_a:ptr<workgroup, bool, read_write> = var
+ %var_b:ptr<workgroup, i32, read_write> = var
+ %var_c:ptr<workgroup, array<i32, 42>, read_write> = var
+ %var_d:ptr<workgroup, array<array<u32, 6>, 7>, read_write> = var
+}
+
+%main = @compute @workgroup_size(11, 2, 3) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %7:bool = eq %tint_local_index, 0u
+ if %7 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %var_a, false
+ store %var_b, 0i
+ exit_if # if_1
+ }
+ }
+ loop [i: %b4, b: %b5, c: %b6] { # loop_1
+ %b4 = block { # initializer
+ next_iteration %b5 %tint_local_index
+ }
+ %b5 = block (%8:u32) { # body
+ %9:bool = gt %8:u32, 41u
+ if %9 [t: %b7] { # if_2
+ %b7 = block { # true
+ exit_loop # loop_1
+ }
+ }
+ %10:ptr<workgroup, i32, read_write> = access %var_c, %8:u32
+ store %10, 0i
+ %11:u32 = mod %8:u32, 6u
+ %12:u32 = div %8:u32, 6u
+ %13:ptr<workgroup, u32, read_write> = access %var_d, %12, %11
+ store %13, 0u
+ continue %b6
+ }
+ %b6 = block { # continuing
+ %14:u32 = add %8:u32, 66u
+ next_iteration %b5 %14
+ }
+ }
+ %15:void = workgroupBarrier
+ %16:bool = load %var_a
+ %17:i32 = load %var_b
+ %18:array<i32, 42> = load %var_c
+ %19:array<array<u32, 6>, 7> = load %var_d
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ExistingLocalInvocationIndex) {
+ auto* var = MakeVar("wgvar", ty.bool_());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ auto* global_id = b.FunctionParam("global_id", ty.vec4<u32>());
+ global_id->SetBuiltin(FunctionParam::Builtin::kGlobalInvocationId);
+ auto* index = b.FunctionParam("index", ty.u32());
+ index->SetBuiltin(FunctionParam::Builtin::kLocalInvocationIndex);
+ func->SetParams({global_id, index});
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%global_id:vec4<u32> [@global_invocation_id], %index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %5:bool = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%global_id:vec4<u32> [@global_invocation_id], %index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %5:bool = eq %index, 0u
+ if %5 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, false
+ exit_if # if_1
+ }
+ }
+ %6:void = workgroupBarrier
+ %7:bool = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, ExistingLocalInvocationIndexInStruct) {
+ auto* var = MakeVar("wgvar", ty.bool_());
+
+ auto* structure =
+ ty.Struct(mod.symbols.New("MyStruct"),
+ {
+ {
+ mod.symbols.New("global_id"),
+ ty.vec3<u32>(),
+ {{}, {}, core::BuiltinValue::kGlobalInvocationId, {}, false},
+ },
+ {
+ mod.symbols.New("index"),
+ ty.u32(),
+ {{}, {}, core::BuiltinValue::kLocalInvocationIndex, {}, false},
+ },
+ });
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ func->SetParams({b.FunctionParam("params", structure)});
+ b.Append(func->Block(), [&] { //
+ b.Load(var);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+MyStruct = struct @align(16) {
+ global_id:vec3<u32> @offset(0), @builtin(global_invocation_id)
+ index:u32 @offset(12), @builtin(local_invocation_index)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%params:MyStruct):void -> %b2 {
+ %b2 = block {
+ %4:bool = load %wgvar
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+MyStruct = struct @align(16) {
+ global_id:vec3<u32> @offset(0), @builtin(global_invocation_id)
+ index:u32 @offset(12), @builtin(local_invocation_index)
+}
+
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%params:MyStruct):void -> %b2 {
+ %b2 = block {
+ %4:u32 = access %params, 1u
+ %5:bool = eq %4, 0u
+ if %5 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, false
+ exit_if # if_1
+ }
+ }
+ %6:void = workgroupBarrier
+ %7:bool = load %wgvar
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, UseInsideNestedBlock) {
+ auto* var = MakeVar("wgvar", ty.bool_());
+
+ auto* func = MakeEntryPoint("main", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ auto* ifelse = b.If(true);
+ b.Append(ifelse->True(), [&] { //
+ auto* sw = b.Switch(42_i);
+ auto* def_case = b.Case(sw, Vector{core::ir::Switch::CaseSelector()});
+ b.Append(def_case, [&] { //
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { //
+ b.Continue(loop);
+ b.Append(loop->Continuing(), [&] { //
+ auto* load = b.Load(var);
+ b.BreakIf(loop, load);
+ });
+ });
+ b.ExitSwitch(sw);
+ });
+ b.ExitIf(ifelse);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func():void -> %b2 {
+ %b2 = block {
+ if true [t: %b3] { # if_1
+ %b3 = block { # true
+ switch 42i [c: (default, %b4)] { # switch_1
+ %b4 = block { # case
+ loop [b: %b5, c: %b6] { # loop_1
+ %b5 = block { # body
+ continue %b6
+ }
+ %b6 = block { # continuing
+ %3:bool = load %wgvar
+ break_if %3 %b5
+ }
+ }
+ exit_switch # switch_1
+ }
+ }
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%main = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b2 {
+ %b2 = block {
+ %4:bool = eq %tint_local_index, 0u
+ if %4 [t: %b3] { # if_1
+ %b3 = block { # true
+ store %wgvar, false
+ exit_if # if_1
+ }
+ }
+ %5:void = workgroupBarrier
+ if true [t: %b4] { # if_2
+ %b4 = block { # true
+ switch 42i [c: (default, %b5)] { # switch_1
+ %b5 = block { # case
+ loop [b: %b6, c: %b7] { # loop_1
+ %b6 = block { # body
+ continue %b7
+ }
+ %b7 = block { # continuing
+ %6:bool = load %wgvar
+ break_if %6 %b6
+ }
+ }
+ exit_switch # switch_1
+ }
+ }
+ exit_if # if_2
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, UseInsideIndirectFunctionCall) {
+ auto* var = MakeVar("wgvar", ty.bool_());
+
+ auto* foo = b.Function("foo", ty.void_());
+ b.Append(foo->Block(), [&] { //
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { //
+ b.Continue(loop);
+ b.Append(loop->Continuing(), [&] { //
+ auto* load = b.Load(var);
+ b.BreakIf(loop, load);
+ });
+ });
+ b.Return(foo);
+ });
+
+ auto* bar = b.Function("foo", ty.void_());
+ b.Append(bar->Block(), [&] { //
+ auto* ifelse = b.If(true);
+ b.Append(ifelse->True(), [&] { //
+ b.Call(ty.void_(), foo);
+ b.ExitIf(ifelse);
+ });
+ b.Return(bar);
+ });
+
+ auto* func = MakeEntryPoint("func", 1, 1, 1);
+ b.Append(func->Block(), [&] { //
+ auto* ifelse = b.If(true);
+ b.Append(ifelse->True(), [&] { //
+ auto* sw = b.Switch(42_i);
+ auto* def_case = b.Case(sw, Vector{core::ir::Switch::CaseSelector()});
+ b.Append(def_case, [&] { //
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { //
+ b.Continue(loop);
+ b.Append(loop->Continuing(), [&] { //
+ b.Call(ty.void_(), bar);
+ b.BreakIf(loop, true);
+ });
+ });
+ b.ExitSwitch(sw);
+ });
+ b.ExitIf(ifelse);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4] { # loop_1
+ %b3 = block { # body
+ continue %b4
+ }
+ %b4 = block { # continuing
+ %3:bool = load %wgvar
+ break_if %3 %b3
+ }
+ }
+ ret
+ }
+}
+%foo_1 = func():void -> %b5 { # %foo_1: 'foo'
+ %b5 = block {
+ if true [t: %b6] { # if_1
+ %b6 = block { # true
+ %5:void = call %foo
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%func = @compute @workgroup_size(1, 1, 1) func():void -> %b7 {
+ %b7 = block {
+ if true [t: %b8] { # if_2
+ %b8 = block { # true
+ switch 42i [c: (default, %b9)] { # switch_1
+ %b9 = block { # case
+ loop [b: %b10, c: %b11] { # loop_2
+ %b10 = block { # body
+ continue %b11
+ }
+ %b11 = block { # continuing
+ %7:void = call %foo_1
+ break_if true %b10
+ }
+ }
+ exit_switch # switch_1
+ }
+ }
+ exit_if # if_2
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4] { # loop_1
+ %b3 = block { # body
+ continue %b4
+ }
+ %b4 = block { # continuing
+ %3:bool = load %wgvar
+ break_if %3 %b3
+ }
+ }
+ ret
+ }
+}
+%foo_1 = func():void -> %b5 { # %foo_1: 'foo'
+ %b5 = block {
+ if true [t: %b6] { # if_1
+ %b6 = block { # true
+ %5:void = call %foo
+ exit_if # if_1
+ }
+ }
+ ret
+ }
+}
+%func = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b7 {
+ %b7 = block {
+ %8:bool = eq %tint_local_index, 0u
+ if %8 [t: %b8] { # if_2
+ %b8 = block { # true
+ store %wgvar, false
+ exit_if # if_2
+ }
+ }
+ %9:void = workgroupBarrier
+ if true [t: %b9] { # if_3
+ %b9 = block { # true
+ switch 42i [c: (default, %b10)] { # switch_1
+ %b10 = block { # case
+ loop [b: %b11, c: %b12] { # loop_2
+ %b11 = block { # body
+ continue %b12
+ }
+ %b12 = block { # continuing
+ %10:void = call %foo_1
+ break_if true %b11
+ }
+ }
+ exit_switch # switch_1
+ }
+ }
+ exit_if # if_3
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(IR_ZeroInitWorkgroupMemoryTest, MultipleEntryPoints_SameVarViaHelper) {
+ auto* var = MakeVar("wgvar", ty.bool_());
+
+ auto* foo = b.Function("foo", ty.void_());
+ b.Append(foo->Block(), [&] { //
+ auto* loop = b.Loop();
+ b.Append(loop->Body(), [&] { //
+ b.Continue(loop);
+ b.Append(loop->Continuing(), [&] { //
+ auto* load = b.Load(var);
+ b.BreakIf(loop, load);
+ });
+ });
+ b.Return(foo);
+ });
+
+ auto* ep1 = MakeEntryPoint("ep1", 1, 1, 1);
+ b.Append(ep1->Block(), [&] { //
+ b.Call(ty.void_(), foo);
+ b.Return(ep1);
+ });
+
+ auto* ep2 = MakeEntryPoint("ep2", 1, 1, 1);
+ b.Append(ep2->Block(), [&] { //
+ b.Call(ty.void_(), foo);
+ b.Return(ep2);
+ });
+
+ auto* src = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4] { # loop_1
+ %b3 = block { # body
+ continue %b4
+ }
+ %b4 = block { # continuing
+ %3:bool = load %wgvar
+ break_if %3 %b3
+ }
+ }
+ ret
+ }
+}
+%ep1 = @compute @workgroup_size(1, 1, 1) func():void -> %b5 {
+ %b5 = block {
+ %5:void = call %foo
+ ret
+ }
+}
+%ep2 = @compute @workgroup_size(1, 1, 1) func():void -> %b6 {
+ %b6 = block {
+ %7:void = call %foo
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%b1 = block { # root
+ %wgvar:ptr<workgroup, bool, read_write> = var
+}
+
+%foo = func():void -> %b2 {
+ %b2 = block {
+ loop [b: %b3, c: %b4] { # loop_1
+ %b3 = block { # body
+ continue %b4
+ }
+ %b4 = block { # continuing
+ %3:bool = load %wgvar
+ break_if %3 %b3
+ }
+ }
+ ret
+ }
+}
+%ep1 = @compute @workgroup_size(1, 1, 1) func(%tint_local_index:u32 [@local_invocation_index]):void -> %b5 {
+ %b5 = block {
+ %6:bool = eq %tint_local_index, 0u
+ if %6 [t: %b6] { # if_1
+ %b6 = block { # true
+ store %wgvar, false
+ exit_if # if_1
+ }
+ }
+ %7:void = workgroupBarrier
+ %8:void = call %foo
+ ret
+ }
+}
+%ep2 = @compute @workgroup_size(1, 1, 1) func(%tint_local_index_1:u32 [@local_invocation_index]):void -> %b7 { # %tint_local_index_1: 'tint_local_index'
+ %b7 = block {
+ %11:bool = eq %tint_local_index_1, 0u
+ if %11 [t: %b8] { # if_2
+ %b8 = block { # true
+ store %wgvar, false
+ exit_if # if_2
+ }
+ }
+ %12:void = workgroupBarrier
+ %13:void = call %foo
+ ret
+ }
+}
+)";
+
+ Run(ZeroInitWorkgroupMemory);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::core::ir::transform
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 6a5c39f..93989a4 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -25,6 +25,7 @@
#include "src/tint/lang/core/ir/transform/multiplanar_external_texture.h"
#include "src/tint/lang/core/ir/transform/robustness.h"
#include "src/tint/lang/core/ir/transform/std140.h"
+#include "src/tint/lang/core/ir/transform/zero_init_workgroup_memory.h"
#include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
#include "src/tint/lang/spirv/writer/raise/expand_implicit_splats.h"
#include "src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h"
@@ -70,6 +71,11 @@
RUN_TRANSFORM(core::ir::transform::MultiplanarExternalTexture, module,
options.external_texture_options);
+ if (!options.disable_workgroup_init &&
+ !options.use_zero_initialize_workgroup_memory_extension) {
+ RUN_TRANSFORM(core::ir::transform::ZeroInitWorkgroupMemory, module);
+ }
+
RUN_TRANSFORM(core::ir::transform::AddEmptyEntryPoint, module);
RUN_TRANSFORM(core::ir::transform::Bgra8UnormPolyfill, module);
RUN_TRANSFORM(core::ir::transform::BlockDecoratedStructs, module);