[tint] Polyfill wide-range switch statements with if/else
Bug: 443906252
Change-Id: Ibec456e950717af8b6a2e3ff192291703026f7d9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/262994
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Peter McNeeley <petermcneeley@google.com>
(cherry picked from commit 81eda65489d8a1781ee4a076f8a3d2dceda03835)
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/268734
diff --git a/src/dawn/native/Toggles.cpp b/src/dawn/native/Toggles.cpp
index 5869846..37b586f 100644
--- a/src/dawn/native/Toggles.cpp
+++ b/src/dawn/native/Toggles.cpp
@@ -566,6 +566,10 @@
"and unpack4xU8() on D3D12 backends. Note that these functions are always polyfilled on all "
"other backends right now.",
"https://crbug.com/tint/1497", ToggleStage::Device}},
+ {Toggle::VulkanPolyfillSwitchWithIf,
+ {"vulkan_polyfill_switch_with_if",
+ "Polyfill switch statements with if/else statements on Vulkan.",
+ "https://crbug.com/443906252", ToggleStage::Device}},
{Toggle::ExposeWGSLTestingFeatures,
{"expose_wgsl_testing_features",
"Make the Instance expose the ChromiumTesting* features for testing of "
diff --git a/src/dawn/native/Toggles.h b/src/dawn/native/Toggles.h
index d57717d..33bc0c4 100644
--- a/src/dawn/native/Toggles.h
+++ b/src/dawn/native/Toggles.h
@@ -138,6 +138,7 @@
PolyfillPackUnpack4x8Norm,
EnableSubgroupsIntelGen9,
D3D12PolyFillPackUnpack4x8,
+ VulkanPolyfillSwitchWithIf,
ExposeWGSLTestingFeatures,
ExposeWGSLExperimentalFeatures,
DisablePolyfillsOnIntegerDivisonAndModulo,
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
index f8d532a..e6b5412 100644
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
@@ -974,6 +974,12 @@
deviceToggles->Default(Toggle::VulkanDirectVariableAccessTransformHandle, true);
}
+ // Pixel 10 is the only device seen with PowerVR D-Series DXT-48-1536 so far.
+ if (gpu_info::IsImgTec(GetVendorId()) && GetDeviceId() == 0x71061212) {
+ // crbug.com/443906252: Polyfill for case switch with large ranges.
+ deviceToggles->Default(Toggle::VulkanPolyfillSwitchWithIf, true);
+ }
+
if (IsAndroidARM()) {
// dawn:1550: Resolving multiple color targets in a single pass fails on ARM GPUs. To
// work around the issue, passes that resolve to multiple color targets will instead be
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index ec42b72..12f26fc 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -302,6 +302,8 @@
GetDevice()->IsToggleEnabled(Toggle::PolyFillPacked4x8DotProduct);
req.tintOptions.polyfill_pack_unpack_4x8_norm =
GetDevice()->IsToggleEnabled(Toggle::PolyfillPackUnpack4x8Norm);
+ req.tintOptions.polyfill_case_switch =
+ GetDevice()->IsToggleEnabled(Toggle::VulkanPolyfillSwitchWithIf);
req.tintOptions.polyfill_subgroup_broadcast_f16 =
GetDevice()->IsToggleEnabled(Toggle::EnableSubgroupsIntelGen9);
req.tintOptions.disable_polyfill_integer_div_mod =
diff --git a/src/dawn/tests/end2end/PolyfillBuiltinSimpleTests.cpp b/src/dawn/tests/end2end/PolyfillBuiltinSimpleTests.cpp
index 53ca6ac..97a7b3d 100644
--- a/src/dawn/tests/end2end/PolyfillBuiltinSimpleTests.cpp
+++ b/src/dawn/tests/end2end/PolyfillBuiltinSimpleTests.cpp
@@ -229,6 +229,142 @@
EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
}
+TEST_P(PolyfillBuiltinSimpleTests, CaseSwitchToIf) {
+    std::string kShaderCode = R"(
+        struct Data { values: array<i32> };
+        @group(0) @binding(0) var<storage, read> input_data: Data;
+        @group(0) @binding(1) var<storage, read_write> output_data: Data;
+
+        @compute @workgroup_size(4)
+        fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+            var input_ = input_data.values[global_id.x];
+            var ret = 0i;
+            switch( input_ ) {
+                case 1: {
+                    ret = 3;
+                }
+                case 2:{
+                    ret = 7;
+                }
+                case -2147483648:{
+                    ret = 71;
+                }
+                case 123, 87:{
+                    ret = 11;
+                }
+                case -1:{
+                    ret = 33;
+                }
+                default {
+                    ret = 82;
+                }
+            }
+            output_data.values[global_id.x] = ret;
+        }
+    )";
+
+    wgpu::ComputePipeline pipeline = CreateComputePipeline(kShaderCode);
+    constexpr uint32_t kDefaultVal = 0;
+    std::vector<uint32_t> init_input = {uint32_t(std::numeric_limits<int32_t>::lowest()),
+                                        uint32_t(-15), 17, 123};
+
+    wgpu::Buffer input = CreateBuffer(init_input);
+    wgpu::Buffer output = CreateBuffer(4, kDefaultVal);
+    wgpu::BindGroup bindGroup =
+        utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, input}, {1, output}});
+
+    wgpu::CommandBuffer commands;
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(pipeline);
+        pass.SetBindGroup(0, bindGroup);
+        pass.DispatchWorkgroups(1);  // 1 workgroup x workgroup_size(4): one thread per element.
+        pass.End();
+        commands = encoder.Finish();
+    }
+
+    queue.Submit(1, &commands);
+    std::vector<uint32_t> expected = {71, 82, 82, 11};
+
+    EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
+}
+
+TEST_P(PolyfillBuiltinSimpleTests, CaseSwitchToIfComplex) {
+    std::string kShaderCode = R"(
+        @group(0) @binding(0) var<storage, read> input_data: array<i32>;
+        @group(0) @binding(1) var<storage, read_write> output_data: array<i32>;
+
+        @compute @workgroup_size(4)
+        fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+            var input_ = input_data[global_id.x];
+            var ret = 0i;
+            switch( input_ ) {
+                case 1: {
+                    ret = 3;
+                }
+                case -2:{
+                    switch(input_){
+                        case 1: {
+                            ret = 3;
+                        }
+                        case -2:{
+                            ret = 4;
+                        }
+                        default{
+                            ret = 99;
+                        }
+                    }
+                    break;
+                    ret = 7;
+                }
+                case -2147483648:{
+                    if(input_ == 17){
+                        ret = 71;
+                        break;
+                    }
+                    ret = 13;
+                }
+                case 3, 5:{
+                    if(input_ == 3){
+                        break;
+                    }
+                    ret = 11;
+                }
+                default {
+                    ret = 82;
+                }
+            }
+            output_data[global_id.x] = ret;
+        }
+    )";
+
+    wgpu::ComputePipeline pipeline = CreateComputePipeline(kShaderCode);
+    constexpr uint32_t kDefaultVal = 0;
+    std::vector<uint32_t> init_input = {uint32_t(std::numeric_limits<int32_t>::lowest()),
+                                        uint32_t(-2), 3, 5};
+    std::vector<uint32_t> expected = {13, 4, 0, 11};
+    wgpu::Buffer input = CreateBuffer(init_input);
+    wgpu::Buffer output = CreateBuffer(4, kDefaultVal);
+    wgpu::BindGroup bindGroup =
+        utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, input}, {1, output}});
+
+    wgpu::CommandBuffer commands;
+    {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+        pass.SetPipeline(pipeline);
+        pass.SetBindGroup(0, bindGroup);
+        pass.DispatchWorkgroups(1);  // 1 workgroup x workgroup_size(4): one thread per element.
+        pass.End();
+        commands = encoder.Finish();
+    }
+
+    queue.Submit(1, &commands);
+
+    EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), output, 0, expected.size());
+}
+
DAWN_INSTANTIATE_TEST(PolyfillBuiltinSimpleTests,
D3D12Backend(),
D3D11Backend(),
@@ -237,6 +373,7 @@
D3D12Backend({"scalarize_max_min_clamp"}),
MetalBackend({"scalarize_max_min_clamp"}),
VulkanBackend({"scalarize_max_min_clamp"}),
+ VulkanBackend({"vulkan_polyfill_switch_with_if"}),
D3D11Backend({"scalarize_max_min_clamp"}),
OpenGLESBackend());
diff --git a/src/tint/lang/core/ir/switch.h b/src/tint/lang/core/ir/switch.h
index b6c44d9..55cf61f 100644
--- a/src/tint/lang/core/ir/switch.h
+++ b/src/tint/lang/core/ir/switch.h
@@ -106,6 +106,13 @@
/// @returns the switch cases
Vector<Case, 4>& Cases() { return cases_; }
+    /// @returns the switch cases, moved out of this switch (their blocks are not re-parented)
+    Vector<Case, 4> TakeCases() {
+        Vector<Case, 4> taken = std::move(cases_);
+        cases_.Clear();  // Leave this switch with a known-empty case list after the move.
+        return taken;
+    }
+
/// @returns the switch cases
VectorRef<Case> Cases() const { return cases_; }
diff --git a/src/tint/lang/spirv/writer/common/options.h b/src/tint/lang/spirv/writer/common/options.h
index 521faf3..c4a98b0 100644
--- a/src/tint/lang/spirv/writer/common/options.h
+++ b/src/tint/lang/spirv/writer/common/options.h
@@ -118,6 +118,9 @@
/// `unpack4x8unorm` builtins
bool polyfill_pack_unpack_4x8_norm = false;
+ /// Set to `true` to generate a polyfill for switch statements using if/else statements.
+ bool polyfill_case_switch = false;
+
/// Set to `true` to generate a polyfill clamp of `id` param of subgroupShuffle to within the
/// spec max subgroup size.
bool subgroup_shuffle_clamped = false;
@@ -163,6 +166,7 @@
pass_matrix_by_pointer,
polyfill_dot_4x8_packed,
polyfill_pack_unpack_4x8_norm,
+ polyfill_case_switch,
subgroup_shuffle_clamped,
polyfill_subgroup_broadcast_f16,
disable_polyfill_integer_div_mod,
diff --git a/src/tint/lang/spirv/writer/raise/BUILD.bazel b/src/tint/lang/spirv/writer/raise/BUILD.bazel
index 6d48251..80e7efa 100644
--- a/src/tint/lang/spirv/writer/raise/BUILD.bazel
+++ b/src/tint/lang/spirv/writer/raise/BUILD.bazel
@@ -40,6 +40,7 @@
name = "raise",
srcs = [
"builtin_polyfill.cc",
+ "case_switch_to_if_else.cc",
"expand_implicit_splats.cc",
"fork_explicit_layout_types.cc",
"handle_matrix_arithmetic.cc",
@@ -54,6 +55,7 @@
],
hdrs = [
"builtin_polyfill.h",
+ "case_switch_to_if_else.h",
"expand_implicit_splats.h",
"fork_explicit_layout_types.h",
"handle_matrix_arithmetic.h",
@@ -78,6 +80,7 @@
"//src/tint/lang/spirv/intrinsic",
"//src/tint/lang/spirv/ir",
"//src/tint/lang/spirv/type",
+ "//src/tint/lang/wgsl/ast",
"//src/tint/utils",
"//src/tint/utils/containers",
"//src/tint/utils/diagnostic",
@@ -108,6 +111,7 @@
alwayslink = True,
srcs = [
"builtin_polyfill_test.cc",
+ "case_switch_to_if_else_test.cc",
"expand_implicit_splats_test.cc",
"fork_explicit_layout_types_test.cc",
"handle_matrix_arithmetic_test.cc",
diff --git a/src/tint/lang/spirv/writer/raise/BUILD.cmake b/src/tint/lang/spirv/writer/raise/BUILD.cmake
index 0580e4c..8150d9e 100644
--- a/src/tint/lang/spirv/writer/raise/BUILD.cmake
+++ b/src/tint/lang/spirv/writer/raise/BUILD.cmake
@@ -43,6 +43,8 @@
tint_add_target(tint_lang_spirv_writer_raise lib
lang/spirv/writer/raise/builtin_polyfill.cc
lang/spirv/writer/raise/builtin_polyfill.h
+ lang/spirv/writer/raise/case_switch_to_if_else.cc
+ lang/spirv/writer/raise/case_switch_to_if_else.h
lang/spirv/writer/raise/expand_implicit_splats.cc
lang/spirv/writer/raise/expand_implicit_splats.h
lang/spirv/writer/raise/fork_explicit_layout_types.cc
@@ -79,6 +81,7 @@
tint_lang_spirv_intrinsic
tint_lang_spirv_ir
tint_lang_spirv_type
+ tint_lang_wgsl_ast
tint_utils
tint_utils_containers
tint_utils_diagnostic
@@ -116,6 +119,7 @@
################################################################################
tint_add_target(tint_lang_spirv_writer_raise_test test
lang/spirv/writer/raise/builtin_polyfill_test.cc
+ lang/spirv/writer/raise/case_switch_to_if_else_test.cc
lang/spirv/writer/raise/expand_implicit_splats_test.cc
lang/spirv/writer/raise/fork_explicit_layout_types_test.cc
lang/spirv/writer/raise/handle_matrix_arithmetic_test.cc
diff --git a/src/tint/lang/spirv/writer/raise/BUILD.gn b/src/tint/lang/spirv/writer/raise/BUILD.gn
index b649e1b..c4c229b 100644
--- a/src/tint/lang/spirv/writer/raise/BUILD.gn
+++ b/src/tint/lang/spirv/writer/raise/BUILD.gn
@@ -47,6 +47,8 @@
sources = [
"builtin_polyfill.cc",
"builtin_polyfill.h",
+ "case_switch_to_if_else.cc",
+ "case_switch_to_if_else.h",
"expand_implicit_splats.cc",
"expand_implicit_splats.h",
"fork_explicit_layout_types.cc",
@@ -83,6 +85,7 @@
"${tint_src_dir}/lang/spirv/intrinsic",
"${tint_src_dir}/lang/spirv/ir",
"${tint_src_dir}/lang/spirv/type",
+ "${tint_src_dir}/lang/wgsl/ast",
"${tint_src_dir}/utils",
"${tint_src_dir}/utils/containers",
"${tint_src_dir}/utils/diagnostic",
@@ -109,6 +112,7 @@
tint_unittests_source_set("unittests") {
sources = [
"builtin_polyfill_test.cc",
+ "case_switch_to_if_else_test.cc",
"expand_implicit_splats_test.cc",
"fork_explicit_layout_types_test.cc",
"handle_matrix_arithmetic_test.cc",
diff --git a/src/tint/lang/spirv/writer/raise/case_switch_to_if_else.cc b/src/tint/lang/spirv/writer/raise/case_switch_to_if_else.cc
new file mode 100644
index 0000000..4026054
--- /dev/null
+++ b/src/tint/lang/spirv/writer/raise/case_switch_to_if_else.cc
@@ -0,0 +1,189 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/writer/raise/case_switch_to_if_else.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <utility>
+
+#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/control_instruction.h"
+#include "src/tint/lang/core/ir/exit_switch.h"
+#include "src/tint/lang/core/ir/if.h"
+#include "src/tint/lang/core/ir/instruction.h"
+#include "src/tint/lang/core/ir/module.h"
+#include "src/tint/lang/core/ir/switch.h"
+#include "src/tint/lang/core/ir/traverse.h"
+#include "src/tint/lang/core/ir/validator.h"
+#include "src/tint/lang/core/ir/value.h"
+#include "src/tint/lang/core/unary_op.h"
+#include "src/tint/lang/wgsl/ast/case_selector.h"
+#include "src/tint/utils/containers/transform.h"
+
+using namespace tint::core::fluent_types;  // NOLINT
+
+namespace tint::spirv::writer::raise {
+
+namespace {
+
+const core::ir::Capabilities kCaseSwitchToIfElseCapabilities{
+    core::ir::Capability::kAllowDuplicateBindings,
+    core::ir::Capability::kAllowAnyInputAttachmentIndexType,
+    core::ir::Capability::kAllowNonCoreTypes,
+};
+
+/// PIMPL state for the transform, covering every function in the module.
+struct State {
+    /// The IR module.
+    core::ir::Module& ir;
+
+    /// The IR builder.
+    core::ir::Builder b{ir};
+
+    /// The type manager.
+    core::type::Manager& ty{ir.Types()};
+
+    /// Process the IR: polyfill every live switch selected by NeedsPolyfill().
+    void Process() {
+        // Collect the switches first, so that the instructions created by Polyfill() are not
+        // added to the module while its instruction list is being iterated.
+        Vector<core::ir::Switch*, 4> worklist;
+        for (auto* inst : ir.Instructions()) {
+            if (!inst->Alive()) {
+                continue;
+            }
+            auto* s = inst->As<core::ir::Switch>();
+            if (s && NeedsPolyfill(s)) {
+                worklist.Push(s);
+            }
+        }
+
+        for (auto* s : worklist) {
+            Polyfill(s);
+        }
+    }
+
+    /// Decides whether @p s is rewritten. Even though the switch param could be signed, the
+    /// selector values are compared as u32 as that is likely the internal representation of
+    /// the compiler.
+    /// @param s the switch to inspect
+    /// @returns true if the non-default selectors of @p s span a u32 range of at least 2^31-2
+    static bool NeedsPolyfill(core::ir::Switch* s) {
+        bool has_selector = false;
+        uint32_t max_sel_case = 0u;
+        uint32_t min_sel_case = std::numeric_limits<uint32_t>::max();
+        for (auto& c : s->Cases()) {
+            for (auto& sel : c.selectors) {
+                if (!sel.IsDefault() && sel.val && sel.val->Value()) {
+                    auto val = sel.val->Value()->ValueAs<u32>().value;
+                    max_sel_case = std::max(max_sel_case, val);
+                    min_sel_case = std::min(min_sel_case, val);
+                    has_selector = true;
+                }
+            }
+        }
+        if (!has_selector) {
+            // Default-only switch: there is no range to check. Returning early avoids relying
+            // on the unsigned wrap-around of `0u - UINT32_MAX` (== 1) in the subtraction below.
+            return false;
+        }
+
+        // Our concern is around handling of signed range calculations (vs unsigned). Any
+        // range that gets close we will polyfill.
+        constexpr uint32_t kSignedRangeLimit =
+            static_cast<uint32_t>(std::numeric_limits<int32_t>::max() - 1);
+        return (max_sel_case - min_sel_case) >= kSignedRangeLimit;
+    }
+
+    /// Rewrites @p s into a switch with a single default case. Its block holds one `if` per
+    /// original non-default case (in the original case order), each guarding that case block,
+    /// followed by an `if true` guarding the original default case block.
+    /// @param s the switch to rewrite
+    void Polyfill(core::ir::Switch* s) {
+        auto* switch_cond = s->Condition();
+        core::ir::Switch::Case* default_case = nullptr;
+        // We take the cases here because we're going to attach the case block to an `if`
+        // statement and the switch will now be replaced with a single default case block.
+        auto cases = s->TakeCases();
+        b.Append(b.DefaultCase(s), [&] {
+            for (auto& c : cases) {
+                // A default case is required by the spec and is emitted last (below). It may
+                // also carry non-default selectors; those are superfluous, as their values
+                // simply end up in the default.
+                const bool is_default = std::any_of(c.selectors.begin(), c.selectors.end(),
+                                                    [](auto& sel) { return sel.IsDefault(); });
+                if (is_default) {
+                    default_case = &c;
+                    continue;
+                }
+
+                // This case is taken when the condition equals any one of its selectors.
+                core::ir::Value* case_cond = nullptr;
+                for (auto& sel : c.selectors) {
+                    auto* curr_selector =
+                        b.Equal(ty.bool_(), switch_cond, sel.val->As<core::ir::Value>());
+                    if (case_cond) {
+                        case_cond = b.Or(ty.bool_(), curr_selector, case_cond)->Result();
+                    } else {
+                        case_cond = curr_selector->Result();
+                    }
+                }
+                TINT_ASSERT(case_cond);
+                auto* if_case = b.If(case_cond);
+                if_case->SetTrue(c.block);
+            }
+
+            TINT_ASSERT(default_case);
+            // Special handling required for default case. All non-default cases have exited the
+            // switch by this point so the only possibility that remains is the default.
+            auto* if_default = b.If(b.Constant(true));
+            if_default->SetTrue(default_case->block);
+
+            // The default case block is always entered above and, like every case block, ends
+            // in a terminator, so the end of this block is never reached.
+            b.Unreachable();
+        });
+    }
+};
+
+}  // namespace
+
+Result<SuccessType> CaseSwitchToIfElse(core::ir::Module& ir) {
+    auto result =
+        ValidateAndDumpIfNeeded(ir, "spirv.CaseSwitchToIfElse", kCaseSwitchToIfElseCapabilities);
+    if (result != Success) {
+        return result;
+    }
+
+    State{ir}.Process();
+
+    return Success;
+}
+
+}  // namespace tint::spirv::writer::raise
diff --git a/src/tint/lang/spirv/writer/raise/case_switch_to_if_else.h b/src/tint/lang/spirv/writer/raise/case_switch_to_if_else.h
new file mode 100644
index 0000000..ae846ff
--- /dev/null
+++ b/src/tint/lang/spirv/writer/raise/case_switch_to_if_else.h
@@ -0,0 +1,48 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_SPIRV_WRITER_RAISE_CASE_SWITCH_TO_IF_ELSE_H_
+#define SRC_TINT_LANG_SPIRV_WRITER_RAISE_CASE_SWITCH_TO_IF_ELSE_H_
+
+#include "src/tint/utils/result.h"
+
+// Forward declarations.
+namespace tint::core::ir {
+class Module;
+}
+
+namespace tint::spirv::writer::raise {
+
+/// CaseSwitchToIfElse rewrites each switch whose non-default selector values, compared as u32,
+/// span at least 2^31-2 into a default-only switch holding one `if` per case, default case last.
+/// @param module the module to transform
+/// @returns success or failure
+Result<SuccessType> CaseSwitchToIfElse(core::ir::Module& module);
+
+} // namespace tint::spirv::writer::raise
+
+#endif // SRC_TINT_LANG_SPIRV_WRITER_RAISE_CASE_SWITCH_TO_IF_ELSE_H_
diff --git a/src/tint/lang/spirv/writer/raise/case_switch_to_if_else_test.cc b/src/tint/lang/spirv/writer/raise/case_switch_to_if_else_test.cc
new file mode 100644
index 0000000..738d69f
--- /dev/null
+++ b/src/tint/lang/spirv/writer/raise/case_switch_to_if_else_test.cc
@@ -0,0 +1,1375 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/writer/raise/case_switch_to_if_else.h"
+
+#include <utility>
+
+#include "src/tint/lang/core/ir/transform/helper_test.h"
+
+namespace tint::spirv::writer::raise {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+using SpirvWriter_CaseSwitchToIfElseTest = core::ir::transform::TransformTest;
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, BasicSwitch) {  // -1, 2 span >= 2^31-2 as u32
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(-1_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(2_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (-1i, $B2), c: (2i, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %3:bool = eq %param0, -1i
+ if %3 [t: $B3] { # if_1
+ $B3: { # true
+ ret
+ }
+ }
+ %4:bool = eq %param0, 2i
+ if %4 [t: $B4] { # if_2
+ $B4: { # true
+ ret
+ }
+ }
+ if true [t: $B5] { # if_3
+ $B5: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, ReorderedBasicSwitch) {  // default still emitted last
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(2_i)}), [&] { //
+ b.Return(func);
+ });
+
+ b.Append(b.Case(s, {b.Constant(-1_i)}), [&] { //
+ b.Return(func);
+ });
+
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2), c: (2i, $B3), c: (-1i, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %3:bool = eq %param0, 2i
+ if %3 [t: $B3] { # if_1
+ $B3: { # true
+ ret
+ }
+ }
+ %4:bool = eq %param0, -1i
+ if %4 [t: $B4] { # if_2
+ $B4: { # true
+ ret
+ }
+ }
+ if true [t: $B5] { # if_3
+ $B5: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, SwitchOnlyDefault) {  // no selectors: left unchanged
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, SwitchWithReturn) {  // NOTE(review): == BasicSwitch
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(-1_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(2_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (-1i, $B2), c: (2i, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %3:bool = eq %param0, -1i
+ if %3 [t: $B3] { # if_1
+ $B3: { # true
+ ret
+ }
+ }
+ %4:bool = eq %param0, 2i
+ if %4 [t: $B4] { # if_2
+ $B4: { # true
+ ret
+ }
+ }
+ if true [t: $B5] { # if_3
+ $B5: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, SwitchMultiSelectorCase) {  // selectors OR'd together
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(-1_i), b.Constant(2_i)}), [&] { b.Return(func); });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (-1i 2i, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %3:bool = eq %param0, -1i
+ %4:bool = eq %param0, 2i
+ %5:bool = or %4, %3
+ if %5 [t: $B3] { # if_1
+ $B3: { # true
+ ret
+ }
+ }
+ if true [t: $B4] { # if_2
+ $B4: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, NestedSwitch) {  // outer and inner both rewritten
+ auto* cond1 = b.FunctionParam("cond1", ty.i32());
+ auto* cond2 = b.FunctionParam("cond2", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond1, cond2});
+ b.Append(func->Block(), [&] {
+ auto* s1 = b.Switch(cond1);
+ b.Append(b.Case(s1, {b.Constant(1_i), b.Constant(-1_i)}), [&] {
+ auto* s2 = b.Switch(cond2);
+ b.Append(b.Case(s2, {b.Constant(-2_i), b.Constant(2_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s2), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+ b.Append(b.DefaultCase(s1), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%cond1:i32, %cond2:i32):void {
+ $B1: {
+ switch %cond1 [c: (1i- -1, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ switch %cond2 [c: (-2i 2i, $B4), c: (default, $B5)] { # switch_2
+ $B4: { # case
+ ret
+ }
+ $B5: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+ $B3: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%cond1:i32, %cond2:i32):void {
+ $B1: {
+ switch %cond1 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %4:bool = eq %cond1, 1i
+ %5:bool = eq %cond1, -1i
+ %6:bool = or %5, %4
+ if %6 [t: $B3] { # if_1
+ $B3: { # true
+ switch %cond2 [c: (default, $B4)] { # switch_2
+ $B4: { # case
+ %7:bool = eq %cond2, -2i
+ %8:bool = eq %cond2, 2i
+ %9:bool = or %8, %7
+ if %9 [t: $B5] { # if_2
+ $B5: { # true
+ ret
+ }
+ }
+ if true [t: $B6] { # if_3
+ $B6: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+ }
+ if true [t: $B7] { # if_4
+ $B7: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, MixedIfAndSwitch) {  // existing if is moved intact
+ auto* cond1 = b.FunctionParam("cond1", ty.i32());
+ auto* cond2 = b.FunctionParam("cond2", ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond1, cond2});
+
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond1);
+ b.Append(b.Case(s, {b.Constant(1_i), b.Constant(-1_i)}), [&] {
+ auto* ifelse = b.If(cond2);
+ b.Append(ifelse->True(), [&] { //
+ b.Return(func);
+ });
+ b.Append(ifelse->False(), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%cond1:i32, %cond2:bool):void {
+ $B1: {
+ switch %cond1 [c: (1i- -1, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ if %cond2 [t: $B4, f: $B5] { # if_1
+ $B4: { # true
+ ret
+ }
+ $B5: { # false
+ ret
+ }
+ }
+ unreachable
+ }
+ $B3: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%cond1:i32, %cond2:bool):void {
+ $B1: {
+ switch %cond1 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %4:bool = eq %cond1, 1i
+ %5:bool = eq %cond1, -1i
+ %6:bool = or %5, %4
+ if %6 [t: $B3] { # if_1
+ $B3: { # true
+ if %cond2 [t: $B4, f: $B5] { # if_2
+ $B4: { # true
+ ret
+ }
+ $B5: { # false
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ if true [t: $B6] { # if_3
+ $B6: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, BasicSwitch_VarAssignment) {  // exit_switch kept in ifs
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var("v", b.Zero(ty.i32()));
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(1_i)}), [&] {
+ b.Store(var, 10_i);
+ b.ExitSwitch(s);
+ });
+ b.Append(b.Case(s, {b.Constant(-2_i)}), [&] {
+ b.Store(var, 20_i);
+ b.ExitSwitch(s);
+ });
+ b.Append(b.DefaultCase(s), [&] {
+ b.Store(var, 30_i);
+ b.ExitSwitch(s);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %param0 [c: (1i, $B2), c: (-2i, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ store %v, 10i
+ exit_switch # switch_1
+ }
+ $B3: { # case
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ $B4: { # case
+ store %v, 30i
+ exit_switch # switch_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %4:bool = eq %param0, 1i
+ if %4 [t: $B3] { # if_1
+ $B3: { # true
+ store %v, 10i
+ exit_switch # switch_1
+ }
+ }
+ %5:bool = eq %param0, -2i
+ if %5 [t: $B4] { # if_2
+ $B4: { # true
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ if true [t: $B5] { # if_3
+ $B5: { # true
+ store %v, 30i
+ exit_switch # switch_1
+ }
+ }
+ unreachable
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, SwitchMultiSelectorCase_VarAssignment) {
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ // One case with two selectors (1i and -2i) plus a default.
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var("v", b.Zero(ty.i32()));
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(1_i), b.Constant(-2_i)}), [&] {
+ b.Store(var, 10_i);
+ b.ExitSwitch(s);
+ });
+ b.Append(b.DefaultCase(s), [&] {
+ b.Store(var, 20_i);
+ b.ExitSwitch(s);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %param0 [c: (1i- -2, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ store %v, 10i
+ exit_switch # switch_1
+ }
+ $B3: { # case
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // Multi-selector case: one `eq` per selector, combined with `or`, guards the moved case body.
+ auto* expect = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %4:bool = eq %param0, 1i
+ %5:bool = eq %param0, -2i
+ %6:bool = or %5, %4
+ if %6 [t: $B3] { # if_1
+ $B3: { # true
+ store %v, 10i
+ exit_switch # switch_1
+ }
+ }
+ if true [t: $B4] { # if_2
+ $B4: { # true
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ unreachable
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, NestedSwitch_VarAssignment) {
+ auto* cond1 = b.FunctionParam("cond1", ty.i32());
+ auto* cond2 = b.FunctionParam("cond2", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond1, cond2});
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var("v", b.Zero(ty.i32()));
+ auto* s1 = b.Switch(cond1);
+ b.Append(b.Case(s1, {b.Constant(1_i), b.Constant(-1_i)}), [&] {
+ auto* s2 = b.Switch(cond2);
+ b.Append(b.Case(s2, {b.Constant(2_i), b.Constant(-2_i)}), [&] {
+ b.Store(var, 10_i);
+ b.ExitSwitch(s2);
+ });
+ b.Append(b.DefaultCase(s2), [&] {
+ b.Store(var, 20_i);
+ b.ExitSwitch(s2);
+ });
+ b.ExitSwitch(s1);
+ });
+ b.Append(b.DefaultCase(s1), [&] {
+ b.Store(var, 30_i);
+ b.ExitSwitch(s1);
+ });
+ b.Return(func);
+ });
+ // Outer switch on cond1 whose first case contains an inner switch on cond2.
+ auto* src = R"(
+%foo = func(%cond1:i32, %cond2:i32):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond1 [c: (1i- -1, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ switch %cond2 [c: (2i- -2, $B4), c: (default, $B5)] { # switch_2
+ $B4: { # case
+ store %v, 10i
+ exit_switch # switch_2
+ }
+ $B5: { # case
+ store %v, 20i
+ exit_switch # switch_2
+ }
+ }
+ exit_switch # switch_1
+ }
+ $B3: { # case
+ store %v, 30i
+ exit_switch # switch_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // Both switches are rewritten; each exit_switch still targets its own (now default-only) switch.
+ auto* expect = R"(
+%foo = func(%cond1:i32, %cond2:i32):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond1 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %5:bool = eq %cond1, 1i
+ %6:bool = eq %cond1, -1i
+ %7:bool = or %6, %5
+ if %7 [t: $B3] { # if_1
+ $B3: { # true
+ switch %cond2 [c: (default, $B4)] { # switch_2
+ $B4: { # case
+ %8:bool = eq %cond2, 2i
+ %9:bool = eq %cond2, -2i
+ %10:bool = or %9, %8
+ if %10 [t: $B5] { # if_2
+ $B5: { # true
+ store %v, 10i
+ exit_switch # switch_2
+ }
+ }
+ if true [t: $B6] { # if_3
+ $B6: { # true
+ store %v, 20i
+ exit_switch # switch_2
+ }
+ }
+ unreachable
+ }
+ }
+ exit_switch # switch_1
+ }
+ }
+ if true [t: $B7] { # if_4
+ $B7: { # true
+ store %v, 30i
+ exit_switch # switch_1
+ }
+ }
+ unreachable
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, MixedIfAndSwitch_VarAssignment) {
+ auto* cond1 = b.FunctionParam("cond1", ty.i32());
+ auto* cond2 = b.FunctionParam("cond2", ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond1, cond2});
+ // The first case's body holds an if/else whose branches end in `exit_if`.
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var("v", b.Zero(ty.i32()));
+ auto* s = b.Switch(cond1);
+ b.Append(b.Case(s, {b.Constant(1_i), b.Constant(-1_i)}), [&] {
+ auto* ifelse = b.If(cond2);
+ b.Append(ifelse->True(), [&] {
+ b.Store(var, 10_i);
+ b.ExitIf(ifelse);
+ });
+ b.Append(ifelse->False(), [&] {
+ b.Store(var, 20_i);
+ b.ExitIf(ifelse);
+ });
+ b.ExitSwitch(s);
+ });
+ b.Append(b.DefaultCase(s), [&] {
+ b.Store(var, 30_i);
+ b.ExitSwitch(s);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%cond1:i32, %cond2:bool):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond1 [c: (1i- -1, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ if %cond2 [t: $B4, f: $B5] { # if_1
+ $B4: { # true
+ store %v, 10i
+ exit_if # if_1
+ }
+ $B5: { # false
+ store %v, 20i
+ exit_if # if_1
+ }
+ }
+ exit_switch # switch_1
+ }
+ $B3: { # case
+ store %v, 30i
+ exit_switch # switch_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // The case body, including the nested if/else, is moved into the new guarding `if` (if_1).
+ auto* expect = R"(
+%foo = func(%cond1:i32, %cond2:bool):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond1 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %5:bool = eq %cond1, 1i
+ %6:bool = eq %cond1, -1i
+ %7:bool = or %6, %5
+ if %7 [t: $B3] { # if_1
+ $B3: { # true
+ if %cond2 [t: $B4, f: $B5] { # if_2
+ $B4: { # true
+ store %v, 10i
+ exit_if # if_2
+ }
+ $B5: { # false
+ store %v, 20i
+ exit_if # if_2
+ }
+ }
+ exit_switch # switch_1
+ }
+ }
+ if true [t: $B6] { # if_3
+ $B6: { # true
+ store %v, 30i
+ exit_switch # switch_1
+ }
+ }
+ unreachable
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, RangeSmall) {
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(1_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(2_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (1i, $B2), c: (2i, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ // 1i and 2i are adjacent, so the transform is expected to leave the switch untouched
+ // (presumably the polyfill only targets wide selector ranges; see RangeSmallUnsignedLarge).
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(src, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, RangeSmallSigned) {
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(-1_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(-2_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (-1i, $B2), c: (-2i, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // -1i and -2i are adjacent negative selectors: the switch is expected to be left unchanged.
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(src, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, RangeSmallSigned2) {
+ auto* cond = b.FunctionParam("param0", ty.i32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(-2147483640_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(2147483640_i)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+ // NOTE(review): these selectors are ~2^32 apart numerically, but their 32-bit patterns differ by 16.
+ auto* src = R"(
+%foo = func(%param0:i32):void {
+ $B1: {
+ switch %param0 [c: (-2147483640i, $B2), c: (2147483640i, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // Expected unchanged, presumably because the range is measured modulo 2^32 -- confirm in the transform.
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(src, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, RangeSmallUnsigned) {
+ auto* cond = b.FunctionParam("param0", ty.u32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(2'000'000'000_u)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(2'000'000'001_u)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+
+ auto* src = R"(
+%foo = func(%param0:u32):void {
+ $B1: {
+ switch %param0 [c: (2000000000u, $B2), c: (2000000001u, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // Large but adjacent u32 selectors: the switch is expected to be left unchanged.
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(src, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, RangeSmallUnsignedLarge) {
+ auto* cond = b.FunctionParam("param0", ty.u32());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond});
+ b.Append(func->Block(), [&] {
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(0_u)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.Case(s, {b.Constant(4'000'000'000_u)}), [&] { //
+ b.Return(func);
+ });
+ b.Append(b.DefaultCase(s), [&] { //
+ b.Return(func);
+ });
+ b.Unreachable();
+ });
+ // 0u and 4000000000u are far apart, so unlike RangeSmallUnsigned the switch is rewritten.
+ auto* src = R"(
+%foo = func(%param0:u32):void {
+ $B1: {
+ switch %param0 [c: (0u, $B2), c: (4000000000u, $B3), c: (default, $B4)] { # switch_1
+ $B2: { # case
+ ret
+ }
+ $B3: { # case
+ ret
+ }
+ $B4: { # case
+ ret
+ }
+ }
+ unreachable
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+%foo = func(%param0:u32):void {
+ $B1: {
+ switch %param0 [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %3:bool = eq %param0, 0u
+ if %3 [t: $B3] { # if_1
+ $B3: { # true
+ ret
+ }
+ }
+ %4:bool = eq %param0, 4000000000u
+ if %4 [t: $B4] { # if_2
+ $B4: { # true
+ ret
+ }
+ }
+ if true [t: $B5] { # if_3
+ $B5: { # true
+ ret
+ }
+ }
+ unreachable
+ }
+ }
+ unreachable
+ }
+}
+)";
+ Run(CaseSwitchToIfElse);
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, BreakInsideIfInCase) {
+ auto* cond = b.FunctionParam("cond", ty.i32());
+ auto* pred = b.FunctionParam("pred", ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond, pred});
+ // Case 1/-1 can leave the switch early from inside a nested `if` (a `break` in WGSL terms).
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var("v", b.Zero(ty.i32()));
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(1_i), b.Constant(-1_i)}), [&] {
+ b.Store(var, 10_i);
+ auto* ifelse = b.If(pred);
+ b.Append(ifelse->True(), [&] {
+ b.Store(var, 11_i);
+ b.ExitSwitch(s); // break
+ });
+ b.Store(var, 12_i);
+ b.ExitSwitch(s);
+ });
+ b.Append(b.DefaultCase(s), [&] {
+ b.Store(var, 20_i);
+ b.ExitSwitch(s);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%cond:i32, %pred:bool):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond [c: (1i- -1, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ store %v, 10i
+ if %pred [t: $B4] { # if_1
+ $B4: { # true
+ store %v, 11i
+ exit_switch # switch_1
+ }
+ }
+ store %v, 12i
+ exit_switch # switch_1
+ }
+ $B3: { # case
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // The early exit_switch inside the nested `if` (now if_2) still targets the original switch_1.
+ auto* expect = R"(
+%foo = func(%cond:i32, %pred:bool):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %5:bool = eq %cond, 1i
+ %6:bool = eq %cond, -1i
+ %7:bool = or %6, %5
+ if %7 [t: $B3] { # if_1
+ $B3: { # true
+ store %v, 10i
+ if %pred [t: $B4] { # if_2
+ $B4: { # true
+ store %v, 11i
+ exit_switch # switch_1
+ }
+ }
+ store %v, 12i
+ exit_switch # switch_1
+ }
+ }
+ if true [t: $B5] { # if_3
+ $B5: { # true
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ unreachable
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+TEST_F(SpirvWriter_CaseSwitchToIfElseTest, MultipleBreaksInCase) {
+ auto* cond = b.FunctionParam("cond", ty.i32());
+ auto* pred1 = b.FunctionParam("pred1", ty.bool_());
+ auto* pred2 = b.FunctionParam("pred2", ty.bool_());
+ auto* func = b.Function("foo", ty.void_());
+ func->SetParams({cond, pred1, pred2});
+ // Case 1/-1 has two conditional early exits before its final store and exit_switch.
+ b.Append(func->Block(), [&] {
+ auto* var = b.Var("v", b.Zero(ty.i32()));
+ auto* s = b.Switch(cond);
+ b.Append(b.Case(s, {b.Constant(1_i), b.Constant(-1_i)}), [&] {
+ b.Store(var, 10_i);
+ auto* if1 = b.If(pred1);
+ b.Append(if1->True(), [&] {
+ b.Store(var, 11_i);
+ b.ExitSwitch(s); // break
+ });
+
+ auto* if2 = b.If(pred2);
+ b.Append(if2->True(), [&] {
+ b.Store(var, 12_i);
+ b.ExitSwitch(s); // break
+ });
+ b.Store(var, 13_i);
+ b.ExitSwitch(s);
+ });
+ b.Append(b.DefaultCase(s), [&] {
+ b.Store(var, 20_i);
+ b.ExitSwitch(s);
+ });
+ b.Return(func);
+ });
+
+ auto* src = R"(
+%foo = func(%cond:i32, %pred1:bool, %pred2:bool):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond [c: (1i- -1, $B2), c: (default, $B3)] { # switch_1
+ $B2: { # case
+ store %v, 10i
+ if %pred1 [t: $B4] { # if_1
+ $B4: { # true
+ store %v, 11i
+ exit_switch # switch_1
+ }
+ }
+ if %pred2 [t: $B5] { # if_2
+ $B5: { # true
+ store %v, 12i
+ exit_switch # switch_1
+ }
+ }
+ store %v, 13i
+ exit_switch # switch_1
+ }
+ $B3: { # case
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+ // All three exit_switch instructions of the case body are kept and still target switch_1.
+ auto* expect = R"(
+%foo = func(%cond:i32, %pred1:bool, %pred2:bool):void {
+ $B1: {
+ %v:ptr<function, i32, read_write> = var 0i
+ switch %cond [c: (default, $B2)] { # switch_1
+ $B2: { # case
+ %6:bool = eq %cond, 1i
+ %7:bool = eq %cond, -1i
+ %8:bool = or %7, %6
+ if %8 [t: $B3] { # if_1
+ $B3: { # true
+ store %v, 10i
+ if %pred1 [t: $B4] { # if_2
+ $B4: { # true
+ store %v, 11i
+ exit_switch # switch_1
+ }
+ }
+ if %pred2 [t: $B5] { # if_3
+ $B5: { # true
+ store %v, 12i
+ exit_switch # switch_1
+ }
+ }
+ store %v, 13i
+ exit_switch # switch_1
+ }
+ }
+ if true [t: $B6] { # if_4
+ $B6: { # true
+ store %v, 20i
+ exit_switch # switch_1
+ }
+ }
+ unreachable
+ }
+ }
+ ret
+ }
+}
+)";
+
+ Run(CaseSwitchToIfElse);
+
+ EXPECT_EQ(expect, str());
+}
+
+} // namespace
+} // namespace tint::spirv::writer::raise
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 1cbca64..3e4ef75 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -52,6 +52,7 @@
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/spirv/writer/common/option_helpers.h"
#include "src/tint/lang/spirv/writer/raise/builtin_polyfill.h"
+#include "src/tint/lang/spirv/writer/raise/case_switch_to_if_else.h"
#include "src/tint/lang/spirv/writer/raise/expand_implicit_splats.h"
#include "src/tint/lang/spirv/writer/raise/fork_explicit_layout_types.h"
#include "src/tint/lang/spirv/writer/raise/handle_matrix_arithmetic.h"
@@ -204,6 +205,9 @@
// kAllowAnyInputAttachmentIndexType required after ExpandImplicitSplats
RUN_TRANSFORM(raise::HandleMatrixArithmetic, module);
RUN_TRANSFORM(raise::MergeReturn, module);
+ if (options.polyfill_case_switch) {
+ RUN_TRANSFORM(raise::CaseSwitchToIfElse, module);
+ }
RUN_TRANSFORM(raise::RemoveUnreachableInLoopContinuing, module);
RUN_TRANSFORM(
raise::ShaderIO, module,