[tint] Use platform demote to helper where possible

This change will only enable platform demote to helper by default for
the vulkan backend.

CTS does test the discard feature:
webgpu:shader,execution,statement,discard:*

Design doc:
https://docs.google.com/document/d/1CEnS-99jspI-ghs_wOpZ8vMGhOW80wQxxTDly87bFN4/edit?usp=sharing&resourcekey=0-o02HubN1NPVz1ssuEyqIPQ


Bug: 42250787, 372714384
Change-Id: I0997e9fe8972c3400cc9e568aca88cbecd74a7dd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/215314
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: Peter McNeeley <petermcneeley@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/dawn/native/Toggles.cpp b/src/dawn/native/Toggles.cpp
index 9d7cc3d..f56bb66 100644
--- a/src/dawn/native/Toggles.cpp
+++ b/src/dawn/native/Toggles.cpp
@@ -207,6 +207,10 @@
      {"disable_workgroup_init",
       "Disables the workgroup memory zero-initialization for compute shaders.",
       "https://crbug.com/tint/1003", ToggleStage::Device}},
+    {Toggle::VulkanUseDemoteToHelperInvocationExtension,
+     {"vulkan_use_demote_to_helper_invocation_extension",
+      "Sets the use of the vulkan demote to helper extension", "https://crbug.com/42250787",
+      ToggleStage::Device}},
     {Toggle::DisableSymbolRenaming,
      {"disable_symbol_renaming", "Disables the WGSL symbol renaming so that names are preserved.",
       "https://crbug.com/dawn/1016", ToggleStage::Device}},
diff --git a/src/dawn/native/Toggles.h b/src/dawn/native/Toggles.h
index bed8bc6..38d8a3e 100644
--- a/src/dawn/native/Toggles.h
+++ b/src/dawn/native/Toggles.h
@@ -72,6 +72,7 @@
     DisallowSpirv,
     DumpShaders,
     DisableWorkgroupInit,
+    VulkanUseDemoteToHelperInvocationExtension,
     DisableSymbolRenaming,
     UseUserDefinedLabelsInBackend,
     UsePlaceholderFragmentInVertexOnlyPipeline,
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index 7d0e740..5107670 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -446,6 +446,12 @@
         featuresChain.Add(&usedKnobs.zeroInitializeWorkgroupMemoryFeatures);
     }
 
+    if (mDeviceInfo.HasExt(DeviceExt::DemoteToHelperInvocation)) {
+        DAWN_ASSERT(usedKnobs.HasExt(DeviceExt::DemoteToHelperInvocation));
+        usedKnobs.demoteToHelperInvocationFeatures = mDeviceInfo.demoteToHelperInvocationFeatures;
+        featuresChain.Add(&usedKnobs.demoteToHelperInvocationFeatures);
+    }
+
     if (mDeviceInfo.HasExt(DeviceExt::ShaderIntegerDotProduct)) {
         DAWN_ASSERT(usedKnobs.HasExt(DeviceExt::ShaderIntegerDotProduct));
 
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
index 7e3702f..8109787 100644
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
@@ -760,6 +760,19 @@
     // extension VK_KHR_zero_initialize_workgroup_memory.
     deviceToggles->Default(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension, true);
 
+    // Spirv OpKill does not do demote to helper and has also been deprecated. Use
+    // OpDemoteToHelperInvocation where the extension is available to get correct platform demote to
+    // helper for "discard".
+    if (!GetDeviceInfo().HasExt(DeviceExt::DemoteToHelperInvocation) ||
+        GetDeviceInfo().demoteToHelperInvocationFeatures.shaderDemoteToHelperInvocation ==
+            VK_FALSE) {
+        deviceToggles->ForceSet(Toggle::VulkanUseDemoteToHelperInvocationExtension, false);
+    }
+
+    // By default we will use the vulkan demote to helper extension if it is available. This gives
+    // us correct fragment shader discard semantics.
+    deviceToggles->Default(Toggle::VulkanUseDemoteToHelperInvocationExtension, true);
+
     // The environment can only request to use StorageInputOutput16 when the capability is
     // available.
     if (GetDeviceInfo()._16BitStorageFeatures.storageInputOutput16 == VK_FALSE) {
diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp
index 6aeacac..1358335 100644
--- a/src/dawn/native/vulkan/ShaderModuleVk.cpp
+++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp
@@ -346,8 +346,14 @@
     req.tintOptions.clamp_frag_depth = clampFragDepth;
     req.tintOptions.disable_robustness = !GetDevice()->IsRobustnessEnabled();
     req.tintOptions.emit_vertex_point_size = emitPointSize;
+
     req.tintOptions.disable_workgroup_init =
         GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
+    // The only possible alternative for the vulkan demote to helper extension is
+    // "OpTerminateInvocation" which remains unimplemented in dawn/tint.
+    req.tintOptions.use_demote_to_helper_invocation_extensions =
+        GetDevice()->IsToggleEnabled(Toggle::VulkanUseDemoteToHelperInvocationExtension);
+
     req.tintOptions.use_zero_initialize_workgroup_memory_extension =
         GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
     req.tintOptions.use_storage_input_output_16 =
diff --git a/src/dawn/native/vulkan/VulkanExtensions.cpp b/src/dawn/native/vulkan/VulkanExtensions.cpp
index a41f882..af222c3 100644
--- a/src/dawn/native/vulkan/VulkanExtensions.cpp
+++ b/src/dawn/native/vulkan/VulkanExtensions.cpp
@@ -174,6 +174,8 @@
     {DeviceExt::ShaderIntegerDotProduct, "VK_KHR_shader_integer_dot_product", VulkanVersion_1_3},
     {DeviceExt::ZeroInitializeWorkgroupMemory, "VK_KHR_zero_initialize_workgroup_memory",
      VulkanVersion_1_3},
+    {DeviceExt::DemoteToHelperInvocation, "VK_EXT_shader_demote_to_helper_invocation",
+     VulkanVersion_1_3},
     {DeviceExt::Maintenance4, "VK_KHR_maintenance4", VulkanVersion_1_3},
     {DeviceExt::SubgroupSizeControl, "VK_EXT_subgroup_size_control", VulkanVersion_1_3},
 
@@ -284,6 +286,7 @@
             case DeviceExt::DepthClipEnable:
             case DeviceExt::ShaderIntegerDotProduct:
             case DeviceExt::ZeroInitializeWorkgroupMemory:
+            case DeviceExt::DemoteToHelperInvocation:
             case DeviceExt::Maintenance4:
             case DeviceExt::Robustness2:
             case DeviceExt::SubgroupSizeControl:
diff --git a/src/dawn/native/vulkan/VulkanExtensions.h b/src/dawn/native/vulkan/VulkanExtensions.h
index c536ff8..9e4756c 100644
--- a/src/dawn/native/vulkan/VulkanExtensions.h
+++ b/src/dawn/native/vulkan/VulkanExtensions.h
@@ -112,6 +112,7 @@
     // Promoted to 1.3
     ShaderIntegerDotProduct,
     ZeroInitializeWorkgroupMemory,
+    DemoteToHelperInvocation,
     Maintenance4,
     SubgroupSizeControl,
 
diff --git a/src/dawn/native/vulkan/VulkanInfo.cpp b/src/dawn/native/vulkan/VulkanInfo.cpp
index b612269..ea78ea4 100644
--- a/src/dawn/native/vulkan/VulkanInfo.cpp
+++ b/src/dawn/native/vulkan/VulkanInfo.cpp
@@ -304,6 +304,12 @@
                 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ZERO_INITIALIZE_WORKGROUP_MEMORY_FEATURES);
         }
 
+        if (info.extensions[DeviceExt::DemoteToHelperInvocation]) {
+            featuresChain.Add(
+                &info.demoteToHelperInvocationFeatures,
+                VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_DEMOTE_TO_HELPER_INVOCATION_FEATURES_EXT);
+        }
+
         if (info.extensions[DeviceExt::Robustness2]) {
             featuresChain.Add(&info.robustness2Features,
                               VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ROBUSTNESS_2_FEATURES_EXT);
diff --git a/src/dawn/native/vulkan/VulkanInfo.h b/src/dawn/native/vulkan/VulkanInfo.h
index 971542c..914ecf5 100644
--- a/src/dawn/native/vulkan/VulkanInfo.h
+++ b/src/dawn/native/vulkan/VulkanInfo.h
@@ -65,6 +65,7 @@
     VkPhysicalDevice16BitStorageFeaturesKHR _16BitStorageFeatures;
     VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroupSizeControlFeatures;
     VkPhysicalDeviceZeroInitializeWorkgroupMemoryFeaturesKHR zeroInitializeWorkgroupMemoryFeatures;
+    VkPhysicalDeviceShaderDemoteToHelperInvocationFeaturesEXT demoteToHelperInvocationFeatures;
     VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shaderIntegerDotProductFeatures;
     VkPhysicalDeviceDepthClipEnableFeaturesEXT depthClipEnableFeatures;
     VkPhysicalDeviceRobustness2FeaturesEXT robustness2Features;
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index b88aa55..2e26a8f 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -614,6 +614,7 @@
     "end2end/DeviceInitializationTests.cpp",
     "end2end/DeviceLifetimeTests.cpp",
     "end2end/DeviceLostTests.cpp",
+    "end2end/DiscardBasicTests.cpp",
     "end2end/DrawIndexedIndirectTests.cpp",
     "end2end/DrawIndexedTests.cpp",
     "end2end/DrawIndirectTests.cpp",
diff --git a/src/dawn/tests/end2end/DiscardBasicTests.cpp b/src/dawn/tests/end2end/DiscardBasicTests.cpp
new file mode 100644
index 0000000..95fbcb6
--- /dev/null
+++ b/src/dawn/tests/end2end/DiscardBasicTests.cpp
@@ -0,0 +1,132 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "dawn/tests/DawnTest.h"
+
+#include "dawn/utils/ComboRenderPipelineDescriptor.h"
+#include "dawn/utils/WGPUHelpers.h"
+
+// This test is just to make sure we still have proper derivatives after discarding
+// fragments in the same quad. A proper implementation of discard is to demote to
+// helper to still allow helper invocations participate in derivative computations.
+
+namespace dawn {
+namespace {
+
+// Render target is the size of a quad.
+constexpr uint32_t kRTSize = 2;
+
+class DiscardBasicTest : public DawnTest {
+  protected:
+    void SetUp() override {
+        DawnTest::SetUp();
+
+        renderPass = utils::CreateBasicRenderPass(device, kRTSize, kRTSize);
+
+        wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
+            @vertex
+            fn main(@location(0) pos : vec4f) -> @builtin(position) vec4f {
+                return pos;
+            })");
+
+        wgpu::ShaderModule fsModule = utils::CreateShaderModule(device, R"(
+            @fragment fn main(@builtin(position) FragCoord : vec4f) -> @location(0) vec4f {
+                // Discard all fragments of quad except one.
+                if(FragCoord.x < 0.7 || FragCoord.y < 0.7){
+                    discard;
+                }
+                // Note: This value is computed after the discard.
+                var post_val:vec2f = (FragCoord.xy * vec2f(0.25));
+                return vec4f(FragCoord.x, FragCoord.y, dpdx(post_val).x, 1.0);
+            })");
+
+        utils::ComboRenderPipelineDescriptor descriptor;
+        descriptor.vertex.module = vsModule;
+        descriptor.cFragment.module = fsModule;
+        descriptor.primitive.topology = wgpu::PrimitiveTopology::TriangleList;
+        descriptor.vertex.bufferCount = 1;
+        descriptor.cBuffers[0].arrayStride = 4 * sizeof(float);
+        descriptor.cBuffers[0].attributeCount = 1;
+        descriptor.cAttributes[0].format = wgpu::VertexFormat::Float32x4;
+        descriptor.cTargets[0].format = renderPass.colorFormat;
+
+        pipeline = device.CreateRenderPipeline(&descriptor);
+
+        vertexBuffer = utils::CreateBufferFromData<float>(
+            device, wgpu::BufferUsage::Vertex,
+            {// The bottom left triangle
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, -1.0f, -1.0f, 0.0f, 1.0f,
+
+             // The top right triangle
+             -1.0f, 1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f});
+    }
+
+    utils::BasicRenderPass renderPass;
+    wgpu::RenderPipeline pipeline;
+    wgpu::Buffer vertexBuffer;
+
+    void Test(uint32_t vertexCount,
+              uint32_t instanceCount,
+              uint32_t firstIndex,
+              uint32_t firstInstance,
+              utils::RGBA8 bottomLeftExpected,
+              utils::RGBA8 topRightExpected) {
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        {
+            wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
+            pass.SetPipeline(pipeline);
+            pass.SetVertexBuffer(0, vertexBuffer);
+            pass.Draw(vertexCount, instanceCount, firstIndex, firstInstance);
+            pass.End();
+        }
+
+        wgpu::CommandBuffer commands = encoder.Finish();
+        queue.Submit(1, &commands);
+
+        EXPECT_PIXEL_RGBA8_EQ(bottomLeftExpected, renderPass.color, 0, 0);
+        EXPECT_PIXEL_RGBA8_EQ(topRightExpected, renderPass.color, 1, 1);
+    }
+};
+
+// The basic triangle draw.
+TEST_P(DiscardBasicTest, DiscardWithDerivative) {
+    utils::RGBA8 filled(255, 255, 64, 255);
+    utils::RGBA8 discarded(0, 0, 0, 0);
+
+    Test(6, 1, 0, 0, discarded, filled);
+}
+
+DAWN_INSTANTIATE_TEST(DiscardBasicTest,
+                      D3D11Backend(),
+                      D3D12Backend(),
+                      MetalBackend(),
+                      OpenGLBackend(),
+                      OpenGLESBackend(),
+                      VulkanBackend());
+
+}  // anonymous namespace
+}  // namespace dawn
diff --git a/src/tint/cmd/tint/main.cc b/src/tint/cmd/tint/main.cc
index 48483cb..121248e 100644
--- a/src/tint/cmd/tint/main.cc
+++ b/src/tint/cmd/tint/main.cc
@@ -181,6 +181,7 @@
     bool verbose = false;
     bool parse_only = false;
     bool disable_workgroup_init = false;
+    bool disable_demote_to_helper = false;
     bool validate = false;
     bool print_hash = false;
     bool dump_inspector_bindings = false;
@@ -398,6 +399,10 @@
         "disable-workgroup-init", "Disable workgroup memory zero initialization", Default{false});
     TINT_DEFER(opts->disable_workgroup_init = *disable_wg_init.value);
 
+    auto& disable_demote_to_helper = options.Add<BoolOption>(
+        "disable-demote-to-helper", "Disable demote to helper for discard", Default{false});
+    TINT_DEFER(opts->disable_demote_to_helper = *disable_demote_to_helper.value);
+
     auto& rename_all = options.Add<BoolOption>("rename-all", "Renames all symbols", Default{false});
     TINT_DEFER(opts->rename_all = *rename_all.value);
 
diff --git a/src/tint/lang/spirv/writer/common/options.h b/src/tint/lang/spirv/writer/common/options.h
index 6d7fd7f..67dddf6 100644
--- a/src/tint/lang/spirv/writer/common/options.h
+++ b/src/tint/lang/spirv/writer/common/options.h
@@ -150,6 +150,9 @@
     /// Set to `true` to disable workgroup memory zero initialization
     bool disable_workgroup_init = false;
 
+    /// Set to `true` to allow for the usage of the demote to helper extension.
+    bool use_demote_to_helper_invocation_extensions = false;
+
     /// Set to `true` to initialize workgroup memory with OpConstantNull when
     /// VK_KHR_zero_initialize_workgroup_memory is enabled.
     bool use_zero_initialize_workgroup_memory_extension = false;
@@ -189,6 +192,7 @@
                  disable_image_robustness,
                  disable_runtime_sized_array_index_clamping,
                  disable_workgroup_init,
+                 use_demote_to_helper_invocation_extensions,
                  use_zero_initialize_workgroup_memory_extension,
                  use_storage_input_output_16,
                  emit_vertex_point_size,
diff --git a/src/tint/lang/spirv/writer/discard_test.cc b/src/tint/lang/spirv/writer/discard_test.cc
index 2a55213..febaf8f 100644
--- a/src/tint/lang/spirv/writer/discard_test.cc
+++ b/src/tint/lang/spirv/writer/discard_test.cc
@@ -141,5 +141,49 @@
 )");
 }
 
+TEST_F(SpirvWriterTest, Discard_DemoteToHelperWithExtension) {
+    Options opts{};
+    opts.use_demote_to_helper_invocation_extensions = true;
+
+    auto* v = b.Var("v", ty.ptr<private_, i32>());
+    v->SetInitializer(b.Constant(42_i));
+    mod.root_block->Append(v);
+
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Discard();
+        auto* load = b.Load(v);
+        auto* add = b.Add(ty.i32(), load, 1_i);
+
+        b.Store(v, add);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate(opts)) << Error() << output_;
+    EXPECT_INST("OpDemoteToHelperInvocation");
+}
+
+TEST_F(SpirvWriterTest, Discard_DemoteToHelperAsTransform) {
+    Options opts{};
+    opts.use_demote_to_helper_invocation_extensions = false;
+
+    auto* v = b.Var("v", ty.ptr<private_, i32>());
+    v->SetInitializer(b.Constant(42_i));
+    mod.root_block->Append(v);
+
+    auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+    b.Append(func->Block(), [&] {
+        b.Discard();
+        auto* load = b.Load(v);
+        auto* add = b.Add(ty.i32(), load, 1_i);
+
+        b.Store(v, add);
+        b.Return(func);
+    });
+
+    ASSERT_TRUE(Generate(opts)) << Error() << output_;
+    EXPECT_INST("continue_execution");
+}
+
 }  // namespace
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index c9ae32b..0a97a88 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -48,6 +48,7 @@
 #include "src/tint/lang/core/ir/continue.h"
 #include "src/tint/lang/core/ir/convert.h"
 #include "src/tint/lang/core/ir/core_builtin_call.h"
+#include "src/tint/lang/core/ir/discard.h"
 #include "src/tint/lang/core/ir/exit_if.h"
 #include "src/tint/lang/core/ir/exit_loop.h"
 #include "src/tint/lang/core/ir/exit_switch.h"
@@ -922,6 +923,7 @@
                 [&](core::ir::Let* l) { EmitLet(l); },                                //
                 [&](core::ir::If* i) { EmitIf(i); },                                  //
                 [&](core::ir::Terminator* t) { EmitTerminator(t); },                  //
+                [&](core::ir::Discard* t) { EmitDiscard(t); },                        //
                 TINT_ICE_ON_NO_MATCH);
 
             // Set the name for the SPIR-V result ID if provided in the module.
@@ -933,6 +935,19 @@
         }
     }
 
+    void EmitDiscard(core::ir::Discard*) {
+        if (options_.use_demote_to_helper_invocation_extensions) {
+            module_.PushExtension("SPV_EXT_demote_to_helper_invocation");
+            module_.PushCapability(SpvCapabilityDemoteToHelperInvocationEXT);
+            current_function_.PushInst(spv::Op::OpDemoteToHelperInvocationEXT, {});
+        } else {
+            // OpKill does not have the same behavioral semantics as demote to helper and will not
+            // be conformant. OpKill has also been deprecated and the alternative
+            // OpTerminateInvocation also does not have demote to helper semantics.
+            TINT_ICE() << "No substitute function for discard";
+        }
+    }
+
     /// Emit a terminator instruction.
     /// @param t the terminator instruction to emit
     void EmitTerminator(core::ir::Terminator* t) {
diff --git a/src/tint/lang/spirv/writer/raise/raise.cc b/src/tint/lang/spirv/writer/raise/raise.cc
index 9395c87..e3b1231 100644
--- a/src/tint/lang/spirv/writer/raise/raise.cc
+++ b/src/tint/lang/spirv/writer/raise/raise.cc
@@ -136,8 +136,10 @@
     // produce pointers to matrices.
     RUN_TRANSFORM(core::ir::transform::CombineAccessInstructions, module);
 
-    // DemoteToHelper must come before any transform that introduces non-core instructions.
-    RUN_TRANSFORM(core::ir::transform::DemoteToHelper, module);
+    if (!options.use_demote_to_helper_invocation_extensions) {
+        // DemoteToHelper must come before any transform that introduces non-core instructions.
+        RUN_TRANSFORM(core::ir::transform::DemoteToHelper, module);
+    }
 
     RUN_TRANSFORM(raise::BuiltinPolyfill, module, options.use_vulkan_memory_model);
     RUN_TRANSFORM(raise::ExpandImplicitSplats, module);