[hlsl-writer] Add support for subgroup inputs

Use CanonicalizeEntryPointIO to map these to HLSL's wave intrinsic
functions.

Bug: tint:2000
Change-Id: I943b4d8414a909c033386323eb12b2f8f21a5740
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144041
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
index 88fd9c6..c0be95f 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
@@ -1133,6 +1133,19 @@
         }
     }
 
+    if (auto* wave_intrinsic =
+            ast::GetAttribute<ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic>(
+                func->Declaration()->attributes)) {
+        switch (wave_intrinsic->op) {
+            case ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Op::kWaveGetLaneCount:
+                out << "WaveGetLaneCount()";
+                return true;
+            case ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Op::kWaveGetLaneIndex:
+                out << "WaveGetLaneIndex()";
+                return true;
+        }
+    }
+
     out << func->Declaration()->name->symbol.Name() << "(";
 
     bool first = true;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
index af0b2f0..457516f 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
@@ -32,6 +32,7 @@
 
 TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO);
 TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::Config);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic);
 
 namespace tint::ast::transform {
 
@@ -142,6 +143,8 @@
     std::unordered_set<std::string> input_names;
     /// A map of cloned attribute to builtin value
     Hashmap<const BuiltinAttribute*, core::BuiltinValue, 16> builtin_attrs;
+    /// A map of builtin values to HLSL wave intrinsic functions.
+    Hashmap<core::BuiltinValue, Symbol, 2> wave_intrinsics;
 
     /// Constructor
     /// @param context the clone context
@@ -224,6 +227,44 @@
         return wrapper_struct_param_name;
     }
 
+    /// Call a wave intrinsic function for the subgroup builtin contained in @p attrs, if any.
+    /// @param attrs the attribute list that may contain a subgroup builtin
+    /// @returns an expression that calls a HLSL wave intrinsic, or nullptr
+    const ast::CallExpression* CallWaveIntrinsic(VectorRef<const Attribute*> attrs) {
+        if (cfg.shader_style != ShaderStyle::kHlsl) {
+            return nullptr;
+        }
+
+        // Helper to make a wave intrinsic.
+        auto make_intrinsic = [&](const char* name, HLSLWaveIntrinsic::Op op) {
+            auto symbol = b.Symbols().New(name);
+            b.Func(symbol, Empty, b.ty.u32(), nullptr,
+                   Vector{b.ASTNodes().Create<HLSLWaveIntrinsic>(b.ID(), b.AllocateNodeID(), op),
+                          b.Disable(DisabledValidation::kFunctionHasNoBody)});
+            return symbol;
+        };
+
+        // Get or create the intrinsic function.
+        auto builtin = BuiltinOf(attrs);
+        auto intrinsic = wave_intrinsics.GetOrCreate(builtin, [&] {
+            if (builtin == core::BuiltinValue::kSubgroupInvocationId) {
+                return make_intrinsic("__WaveGetLaneIndex",
+                                      HLSLWaveIntrinsic::Op::kWaveGetLaneIndex);
+            }
+            if (builtin == core::BuiltinValue::kSubgroupSize) {
+                return make_intrinsic("__WaveGetLaneCount",
+                                      HLSLWaveIntrinsic::Op::kWaveGetLaneCount);
+            }
+            return Symbol();
+        });
+        if (!intrinsic) {
+            return nullptr;
+        }
+
+        // Call the intrinsic function.
+        return b.Call(intrinsic);
+    }
+
     /// Add a shader input to the entry point.
     /// @param name the name of the shader input
     /// @param type the type of the shader input
@@ -349,6 +390,14 @@
     /// that will be passed to the original function.
     /// @param param the original function parameter
     void ProcessNonStructParameter(const sem::Parameter* param) {
+        if (auto* wave_intrinsic_call = CallWaveIntrinsic(param->Declaration()->attributes)) {
+            inner_call_parameters.Push(wave_intrinsic_call);
+            for (auto* attr : param->Declaration()->attributes) {
+                ctx.Remove(param->Declaration()->attributes, attr);
+            }
+            return;
+        }
+
         // Do not add interpolation attributes on vertex input
         bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
         // Remove the shader IO attributes from the inner function parameter, and attach them to the
@@ -388,6 +437,11 @@
                 continue;
             }
 
+            if (auto* wave_intrinsic_call = CallWaveIntrinsic(member->Declaration()->attributes)) {
+                inner_struct_values.Push(wave_intrinsic_call);
+                continue;
+            }
+
             auto name = member->Name().Name();
 
             auto attributes =
@@ -877,4 +931,24 @@
 CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
 CanonicalizeEntryPointIO::Config::~Config() = default;
 
+CanonicalizeEntryPointIO::HLSLWaveIntrinsic::HLSLWaveIntrinsic(GenerationID pid, NodeID nid, Op o)
+    : Base(pid, nid, Empty), op(o) {}
+CanonicalizeEntryPointIO::HLSLWaveIntrinsic::~HLSLWaveIntrinsic() = default;
+std::string CanonicalizeEntryPointIO::HLSLWaveIntrinsic::InternalName() const {
+    StringStream ss;
+    switch (op) {
+        case Op::kWaveGetLaneCount:
+            return "intrinsic_wave_get_lane_count";
+        case Op::kWaveGetLaneIndex:
+            return "intrinsic_wave_get_lane_index";
+    }
+    return ss.str();
+}
+
+const CanonicalizeEntryPointIO::HLSLWaveIntrinsic*
+CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Clone(ast::CloneContext& ctx) const {
+    return ctx.dst->ASTNodes().Create<CanonicalizeEntryPointIO::HLSLWaveIntrinsic>(
+        ctx.dst->ID(), ctx.dst->AllocateNodeID(), op);
+}
+
 }  // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
index d543de0..a6b3715 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
@@ -15,6 +15,9 @@
 #ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
 #define SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
 
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
 #include "src/tint/lang/wgsl/ast/transform/transform.h"
 
 namespace tint::ast::transform {
@@ -123,6 +126,36 @@
         const bool emit_vertex_point_size;
     };
 
+    /// HLSLWaveIntrinsic is an InternalAttribute that is used to decorate a stub function so that
+    /// the HLSL backend transforms this into calls to Wave* intrinsic functions.
+    class HLSLWaveIntrinsic final : public Castable<HLSLWaveIntrinsic, InternalAttribute> {
+      public:
+        /// Wave intrinsic op
+        enum class Op {
+            kWaveGetLaneIndex,
+            kWaveGetLaneCount,
+        };
+
+        /// Constructor
+        /// @param pid the identifier of the program that owns this node
+        /// @param nid the unique node identifier
+        /// @param o the op of the wave intrinsic
+        HLSLWaveIntrinsic(GenerationID pid, NodeID nid, Op o);
+        /// Destructor
+        ~HLSLWaveIntrinsic() override;
+
+        /// @copydoc InternalAttribute::InternalName
+        std::string InternalName() const override;
+
+        /// Performs a deep clone of this object using the program::CloneContext `ctx`.
+        /// @param ctx the clone context
+        /// @return the newly cloned object
+        const HLSLWaveIntrinsic* Clone(CloneContext& ctx) const override;
+
+        /// The op of the intrinsic
+        const Op op;
+    };
+
     /// Constructor
     CanonicalizeEntryPointIO();
     ~CanonicalizeEntryPointIO() override;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
index b924acd..717f39b 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
@@ -4171,5 +4171,88 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(CanonicalizeEntryPointIOTest, SubgroupBuiltins_Hlsl) {
+    auto* src = R"(
+enable chromium_experimental_subgroups;
+
+@compute @workgroup_size(64)
+fn frag_main(@builtin(subgroup_invocation_id) id : u32,
+             @builtin(subgroup_size) size : u32) {
+  let x = size - id;
+}
+)";
+
+    auto* expect = R"(
+enable chromium_experimental_subgroups;
+
+@internal(intrinsic_wave_get_lane_index) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneIndex() -> u32
+
+@internal(intrinsic_wave_get_lane_count) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneCount() -> u32
+
+fn frag_main_inner(id : u32, size : u32) {
+  let x = (size - id);
+}
+
+@compute @workgroup_size(64)
+fn frag_main() {
+  frag_main_inner(__WaveGetLaneIndex(), __WaveGetLaneCount());
+}
+)";
+
+    DataMap data;
+    data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+    auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, SubgroupBuiltinsStruct_Hlsl) {
+    auto* src = R"(
+enable chromium_experimental_subgroups;
+
+struct Inputs {
+  @builtin(subgroup_invocation_id) id : u32,
+  @builtin(subgroup_size) size : u32,
+}
+
+@compute @workgroup_size(64)
+fn frag_main(inputs : Inputs) {
+  let x = inputs.size - inputs.id;
+}
+)";
+
+    auto* expect = R"(
+enable chromium_experimental_subgroups;
+
+@internal(intrinsic_wave_get_lane_index) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneIndex() -> u32
+
+@internal(intrinsic_wave_get_lane_count) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneCount() -> u32
+
+struct Inputs {
+  id : u32,
+  size : u32,
+}
+
+fn frag_main_inner(inputs : Inputs) {
+  let x = (inputs.size - inputs.id);
+}
+
+@compute @workgroup_size(64)
+fn frag_main() {
+  frag_main_inner(Inputs(__WaveGetLaneIndex(), __WaveGetLaneCount()));
+}
+)";
+
+    DataMap data;
+    data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+    auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+    EXPECT_EQ(expect, str(got));
+}
+
 }  // namespace
 }  // namespace tint::ast::transform
diff --git a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl
index c4e9adb..9cf0831 100644
--- a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl
@@ -1,16 +1,11 @@
-SKIP: FAILED
+RWByteAddressBuffer output : register(u0);
 
-
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
-
-@compute @workgroup_size(1)
-fn main(@builtin(subgroup_invocation_id) subgroup_invocation_id : u32, @builtin(subgroup_size) subgroup_size : u32) {
-  output[subgroup_invocation_id] = subgroup_size;
+void main_inner(uint subgroup_invocation_id, uint subgroup_size) {
+  output.Store((4u * subgroup_invocation_id), asuint(subgroup_size));
 }
 
-Failed to generate: shader_io/compute_subgroup_builtins.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
+[numthreads(1, 1, 1)]
+void main() {
+  main_inner(WaveGetLaneIndex(), WaveGetLaneCount());
+  return;
+}
diff --git a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl
index c4e9adb..0bbfaea 100644
--- a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl
@@ -1,16 +1,13 @@
 SKIP: FAILED
 
+RWByteAddressBuffer output : register(u0);
 
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
-
-@compute @workgroup_size(1)
-fn main(@builtin(subgroup_invocation_id) subgroup_invocation_id : u32, @builtin(subgroup_size) subgroup_size : u32) {
-  output[subgroup_invocation_id] = subgroup_size;
+void main_inner(uint subgroup_invocation_id, uint subgroup_size) {
+  output.Store((4u * subgroup_invocation_id), asuint(subgroup_size));
 }
 
-Failed to generate: shader_io/compute_subgroup_builtins.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
+[numthreads(1, 1, 1)]
+void main() {
+  main_inner(WaveGetLaneIndex(), WaveGetLaneCount());
+  return;
+}
diff --git a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl
index 952abb3..8e1d0db 100644
--- a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl
@@ -1,23 +1,17 @@
-SKIP: FAILED
-
-
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+RWByteAddressBuffer output : register(u0);
 
 struct ComputeInputs {
-  @builtin(subgroup_invocation_id)
-  subgroup_invocation_id : u32,
-  @builtin(subgroup_size)
-  subgroup_size : u32,
+  uint subgroup_invocation_id;
+  uint subgroup_size;
+};
+
+void main_inner(ComputeInputs inputs) {
+  output.Store((4u * inputs.subgroup_invocation_id), asuint(inputs.subgroup_size));
 }
 
-@compute @workgroup_size(1)
-fn main(inputs : ComputeInputs) {
-  output[inputs.subgroup_invocation_id] = inputs.subgroup_size;
+[numthreads(1, 1, 1)]
+void main() {
+  const ComputeInputs tint_symbol = {WaveGetLaneIndex(), WaveGetLaneCount()};
+  main_inner(tint_symbol);
+  return;
 }
-
-Failed to generate: shader_io/compute_subgroup_builtins_struct.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
diff --git a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl
index 952abb3..830f68f 100644
--- a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl
@@ -1,23 +1,19 @@
 SKIP: FAILED
 
-
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+RWByteAddressBuffer output : register(u0);
 
 struct ComputeInputs {
-  @builtin(subgroup_invocation_id)
-  subgroup_invocation_id : u32,
-  @builtin(subgroup_size)
-  subgroup_size : u32,
+  uint subgroup_invocation_id;
+  uint subgroup_size;
+};
+
+void main_inner(ComputeInputs inputs) {
+  output.Store((4u * inputs.subgroup_invocation_id), asuint(inputs.subgroup_size));
 }
 
-@compute @workgroup_size(1)
-fn main(inputs : ComputeInputs) {
-  output[inputs.subgroup_invocation_id] = inputs.subgroup_size;
+[numthreads(1, 1, 1)]
+void main() {
+  const ComputeInputs tint_symbol = {WaveGetLaneIndex(), WaveGetLaneCount()};
+  main_inner(tint_symbol);
+  return;
 }
-
-Failed to generate: shader_io/compute_subgroup_builtins_struct.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
-       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-