[hlsl-writer] Add support for subgroup inputs
Use CanonicalizeEntryPointIO to map these to HLSL's wave intrinsic
functions.
Bug: tint:2000
Change-Id: I943b4d8414a909c033386323eb12b2f8f21a5740
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/144041
Auto-Submit: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
index 88fd9c6..c0be95f 100644
--- a/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
+++ b/src/tint/lang/hlsl/writer/ast_printer/ast_printer.cc
@@ -1133,6 +1133,19 @@
}
}
+ if (auto* wave_intrinsic =
+ ast::GetAttribute<ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic>(
+ func->Declaration()->attributes)) {
+ switch (wave_intrinsic->op) {
+ case ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Op::kWaveGetLaneCount:
+ out << "WaveGetLaneCount()";
+ return true;
+ case ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Op::kWaveGetLaneIndex:
+ out << "WaveGetLaneIndex()";
+ return true;
+ }
+ }
+
out << func->Declaration()->name->symbol.Name() << "(";
bool first = true;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
index af0b2f0..457516f 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.cc
@@ -32,6 +32,7 @@
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::Config);
+TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::CanonicalizeEntryPointIO::HLSLWaveIntrinsic);
namespace tint::ast::transform {
@@ -142,6 +143,8 @@
std::unordered_set<std::string> input_names;
/// A map of cloned attribute to builtin value
Hashmap<const BuiltinAttribute*, core::BuiltinValue, 16> builtin_attrs;
+ /// A map of builtin values to HLSL wave intrinsic functions.
+ Hashmap<core::BuiltinValue, Symbol, 2> wave_intrinsics;
/// Constructor
/// @param context the clone context
@@ -224,6 +227,44 @@
return wrapper_struct_param_name;
}
+ /// Call a wave intrinsic function for the subgroup builtin contained in @p attrs, if any.
+ /// @param attrs the attribute list that may contain a subgroup builtin
+ /// @returns an expression that calls a HLSL wave intrinsic, or nullptr
+ const ast::CallExpression* CallWaveIntrinsic(VectorRef<const Attribute*> attrs) {
+ if (cfg.shader_style != ShaderStyle::kHlsl) {
+ return nullptr;
+ }
+
+ // Helper to make a wave intrinsic.
+ auto make_intrinsic = [&](const char* name, HLSLWaveIntrinsic::Op op) {
+ auto symbol = b.Symbols().New(name);
+ b.Func(symbol, Empty, b.ty.u32(), nullptr,
+ Vector{b.ASTNodes().Create<HLSLWaveIntrinsic>(b.ID(), b.AllocateNodeID(), op),
+ b.Disable(DisabledValidation::kFunctionHasNoBody)});
+ return symbol;
+ };
+
+ // Get or create the intrinsic function.
+ auto builtin = BuiltinOf(attrs);
+ auto intrinsic = wave_intrinsics.GetOrCreate(builtin, [&] {
+ if (builtin == core::BuiltinValue::kSubgroupInvocationId) {
+ return make_intrinsic("__WaveGetLaneIndex",
+ HLSLWaveIntrinsic::Op::kWaveGetLaneIndex);
+ }
+ if (builtin == core::BuiltinValue::kSubgroupSize) {
+ return make_intrinsic("__WaveGetLaneCount",
+ HLSLWaveIntrinsic::Op::kWaveGetLaneCount);
+ }
+ return Symbol();
+ });
+ if (!intrinsic) {
+ return nullptr;
+ }
+
+ // Call the intrinsic function.
+ return b.Call(intrinsic);
+ }
+
/// Add a shader input to the entry point.
/// @param name the name of the shader input
/// @param type the type of the shader input
@@ -349,6 +390,14 @@
/// that will be passed to the original function.
/// @param param the original function parameter
void ProcessNonStructParameter(const sem::Parameter* param) {
+ if (auto* wave_intrinsic_call = CallWaveIntrinsic(param->Declaration()->attributes)) {
+ inner_call_parameters.Push(wave_intrinsic_call);
+ for (auto* attr : param->Declaration()->attributes) {
+ ctx.Remove(param->Declaration()->attributes, attr);
+ }
+ return;
+ }
+
// Do not add interpolation attributes on vertex input
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
// Remove the shader IO attributes from the inner function parameter, and attach them to the
@@ -388,6 +437,11 @@
continue;
}
+ if (auto* wave_intrinsic_call = CallWaveIntrinsic(member->Declaration()->attributes)) {
+ inner_struct_values.Push(wave_intrinsic_call);
+ continue;
+ }
+
auto name = member->Name().Name();
auto attributes =
@@ -877,4 +931,24 @@
CanonicalizeEntryPointIO::Config::Config(const Config&) = default;
CanonicalizeEntryPointIO::Config::~Config() = default;
+CanonicalizeEntryPointIO::HLSLWaveIntrinsic::HLSLWaveIntrinsic(GenerationID pid, NodeID nid, Op o)
+ : Base(pid, nid, Empty), op(o) {}
+CanonicalizeEntryPointIO::HLSLWaveIntrinsic::~HLSLWaveIntrinsic() = default;
+std::string CanonicalizeEntryPointIO::HLSLWaveIntrinsic::InternalName() const {
+ StringStream ss;
+ switch (op) {
+ case Op::kWaveGetLaneCount:
+ return "intrinsic_wave_get_lane_count";
+ case Op::kWaveGetLaneIndex:
+ return "intrinsic_wave_get_lane_index";
+ }
+ return ss.str();
+}
+
+const CanonicalizeEntryPointIO::HLSLWaveIntrinsic*
+CanonicalizeEntryPointIO::HLSLWaveIntrinsic::Clone(ast::CloneContext& ctx) const {
+ return ctx.dst->ASTNodes().Create<CanonicalizeEntryPointIO::HLSLWaveIntrinsic>(
+ ctx.dst->ID(), ctx.dst->AllocateNodeID(), op);
+}
+
} // namespace tint::ast::transform
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
index d543de0..a6b3715 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io.h
@@ -15,6 +15,9 @@
#ifndef SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
#define SRC_TINT_LANG_WGSL_AST_TRANSFORM_CANONICALIZE_ENTRY_POINT_IO_H_
+#include <string>
+
+#include "src/tint/lang/wgsl/ast/internal_attribute.h"
#include "src/tint/lang/wgsl/ast/transform/transform.h"
namespace tint::ast::transform {
@@ -123,6 +126,36 @@
const bool emit_vertex_point_size;
};
+ /// HLSLWaveIntrinsic is an InternalAttribute that is used to decorate a stub function so that
+ /// the HLSL backend transforms this into calls to Wave* intrinsic functions.
+ class HLSLWaveIntrinsic final : public Castable<HLSLWaveIntrinsic, InternalAttribute> {
+ public:
+ /// Wave intrinsic op
+ enum class Op {
+ kWaveGetLaneIndex,
+ kWaveGetLaneCount,
+ };
+
+ /// Constructor
+ /// @param pid the identifier of the program that owns this node
+ /// @param nid the unique node identifier
+ /// @param o the op of the wave intrinsic
+ HLSLWaveIntrinsic(GenerationID pid, NodeID nid, Op o);
+ /// Destructor
+ ~HLSLWaveIntrinsic() override;
+
+ /// @copydoc InternalAttribute::InternalName
+ std::string InternalName() const override;
+
+ /// Performs a deep clone of this object using the program::CloneContext `ctx`.
+ /// @param ctx the clone context
+ /// @return the newly cloned object
+ const HLSLWaveIntrinsic* Clone(CloneContext& ctx) const override;
+
+ /// The op of the intrinsic
+ const Op op;
+ };
+
/// Constructor
CanonicalizeEntryPointIO();
~CanonicalizeEntryPointIO() override;
diff --git a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
index b924acd..717f39b 100644
--- a/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
+++ b/src/tint/lang/wgsl/ast/transform/canonicalize_entry_point_io_test.cc
@@ -4171,5 +4171,88 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(CanonicalizeEntryPointIOTest, SubgroupBuiltins_Hlsl) {
+ auto* src = R"(
+enable chromium_experimental_subgroups;
+
+@compute @workgroup_size(64)
+fn frag_main(@builtin(subgroup_invocation_id) id : u32,
+ @builtin(subgroup_size) size : u32) {
+ let x = size - id;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_subgroups;
+
+@internal(intrinsic_wave_get_lane_index) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneIndex() -> u32
+
+@internal(intrinsic_wave_get_lane_count) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneCount() -> u32
+
+fn frag_main_inner(id : u32, size : u32) {
+ let x = (size - id);
+}
+
+@compute @workgroup_size(64)
+fn frag_main() {
+ frag_main_inner(__WaveGetLaneIndex(), __WaveGetLaneCount());
+}
+)";
+
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(CanonicalizeEntryPointIOTest, SubgroupBuiltinsStruct_Hlsl) {
+ auto* src = R"(
+enable chromium_experimental_subgroups;
+
+struct Inputs {
+ @builtin(subgroup_invocation_id) id : u32,
+ @builtin(subgroup_size) size : u32,
+}
+
+@compute @workgroup_size(64)
+fn frag_main(inputs : Inputs) {
+ let x = inputs.size - inputs.id;
+}
+)";
+
+ auto* expect = R"(
+enable chromium_experimental_subgroups;
+
+@internal(intrinsic_wave_get_lane_index) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneIndex() -> u32
+
+@internal(intrinsic_wave_get_lane_count) @internal(disable_validation__function_has_no_body)
+fn __WaveGetLaneCount() -> u32
+
+struct Inputs {
+ id : u32,
+ size : u32,
+}
+
+fn frag_main_inner(inputs : Inputs) {
+ let x = (inputs.size - inputs.id);
+}
+
+@compute @workgroup_size(64)
+fn frag_main() {
+ frag_main_inner(Inputs(__WaveGetLaneIndex(), __WaveGetLaneCount()));
+}
+)";
+
+ DataMap data;
+ data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ auto got = Run<Unshadow, CanonicalizeEntryPointIO>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::ast::transform
diff --git a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl
index c4e9adb..9cf0831 100644
--- a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.dxc.hlsl
@@ -1,16 +1,11 @@
-SKIP: FAILED
+RWByteAddressBuffer output : register(u0);
-
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
-
-@compute @workgroup_size(1)
-fn main(@builtin(subgroup_invocation_id) subgroup_invocation_id : u32, @builtin(subgroup_size) subgroup_size : u32) {
- output[subgroup_invocation_id] = subgroup_size;
+void main_inner(uint subgroup_invocation_id, uint subgroup_size) {
+ output.Store((4u * subgroup_invocation_id), asuint(subgroup_size));
}
-Failed to generate: shader_io/compute_subgroup_builtins.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
+[numthreads(1, 1, 1)]
+void main() {
+ main_inner(WaveGetLaneIndex(), WaveGetLaneCount());
+ return;
+}
diff --git a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl
index c4e9adb..0bbfaea 100644
--- a/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins.wgsl.expected.fxc.hlsl
@@ -1,16 +1,13 @@
SKIP: FAILED
+RWByteAddressBuffer output : register(u0);
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
-
-@compute @workgroup_size(1)
-fn main(@builtin(subgroup_invocation_id) subgroup_invocation_id : u32, @builtin(subgroup_size) subgroup_size : u32) {
- output[subgroup_invocation_id] = subgroup_size;
+void main_inner(uint subgroup_invocation_id, uint subgroup_size) {
+ output.Store((4u * subgroup_invocation_id), asuint(subgroup_size));
}
-Failed to generate: shader_io/compute_subgroup_builtins.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
+[numthreads(1, 1, 1)]
+void main() {
+ main_inner(WaveGetLaneIndex(), WaveGetLaneCount());
+ return;
+}
diff --git a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl
index 952abb3..8e1d0db 100644
--- a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.dxc.hlsl
@@ -1,23 +1,17 @@
-SKIP: FAILED
-
-
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+RWByteAddressBuffer output : register(u0);
struct ComputeInputs {
- @builtin(subgroup_invocation_id)
- subgroup_invocation_id : u32,
- @builtin(subgroup_size)
- subgroup_size : u32,
+ uint subgroup_invocation_id;
+ uint subgroup_size;
+};
+
+void main_inner(ComputeInputs inputs) {
+ output.Store((4u * inputs.subgroup_invocation_id), asuint(inputs.subgroup_size));
}
-@compute @workgroup_size(1)
-fn main(inputs : ComputeInputs) {
- output[inputs.subgroup_invocation_id] = inputs.subgroup_size;
+[numthreads(1, 1, 1)]
+void main() {
+ const ComputeInputs tint_symbol = {WaveGetLaneIndex(), WaveGetLaneCount()};
+ main_inner(tint_symbol);
+ return;
}
-
-Failed to generate: shader_io/compute_subgroup_builtins_struct.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
diff --git a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl
index 952abb3..830f68f 100644
--- a/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl
+++ b/test/tint/shader_io/compute_subgroup_builtins_struct.wgsl.expected.fxc.hlsl
@@ -1,23 +1,19 @@
SKIP: FAILED
-
-enable chromium_experimental_subgroups;
-
-@group(0) @binding(0) var<storage, read_write> output : array<u32>;
+RWByteAddressBuffer output : register(u0);
struct ComputeInputs {
- @builtin(subgroup_invocation_id)
- subgroup_invocation_id : u32,
- @builtin(subgroup_size)
- subgroup_size : u32,
+ uint subgroup_invocation_id;
+ uint subgroup_size;
+};
+
+void main_inner(ComputeInputs inputs) {
+ output.Store((4u * inputs.subgroup_invocation_id), asuint(inputs.subgroup_size));
}
-@compute @workgroup_size(1)
-fn main(inputs : ComputeInputs) {
- output[inputs.subgroup_invocation_id] = inputs.subgroup_size;
+[numthreads(1, 1, 1)]
+void main() {
+ const ComputeInputs tint_symbol = {WaveGetLaneIndex(), WaveGetLaneCount()};
+ main_inner(tint_symbol);
+ return;
}
-
-Failed to generate: shader_io/compute_subgroup_builtins_struct.wgsl:1:8 error: HLSL backend does not support extension 'chromium_experimental_subgroups'
-enable chromium_experimental_subgroups;
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-