Implement and test num_workgroups for all backends
For HLSL, use the new NumWorkgroupsFromUniform transform, and expose
the binding point to use for the generated uniform as a backend
option.
The MSL mapping is trivial, and it was already implemented for WGSL
and SPIR-V.
Bug: tint:752
Change-Id: I4bd37b5d26181629d72b152fe064a60caf8ecdc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63962
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/docs/origin-trial-changes.md b/docs/origin-trial-changes.md
index 620223d..253c961 100644
--- a/docs/origin-trial-changes.md
+++ b/docs/origin-trial-changes.md
@@ -4,6 +4,7 @@
### New Features
* The size of an array can now be defined using a non-overridable module-scope constant
+* The `num_workgroups` builtin is now supported.
### Fixes
* Hex floats: issue an error when the magnitude is non-zero, and the exponent would cause
diff --git a/src/writer/hlsl/generator.cc b/src/writer/hlsl/generator.cc
index d2ee601..2f9a58a 100644
--- a/src/writer/hlsl/generator.cc
+++ b/src/writer/hlsl/generator.cc
@@ -28,7 +28,8 @@
Result result;
// Sanitize the program.
- auto sanitized_result = Sanitize(program, options.disable_workgroup_init);
+ auto sanitized_result = Sanitize(program, options.root_constant_binding_point,
+ options.disable_workgroup_init);
if (!sanitized_result.program.IsValid()) {
result.success = false;
result.error = sanitized_result.program.Diagnostics().str();
diff --git a/src/writer/hlsl/generator.h b/src/writer/hlsl/generator.h
index 1ff8d98..6930029 100644
--- a/src/writer/hlsl/generator.h
+++ b/src/writer/hlsl/generator.h
@@ -21,6 +21,7 @@
#include <vector>
#include "src/ast/pipeline_stage.h"
+#include "src/sem/binding_point.h"
#include "src/writer/text.h"
namespace tint {
@@ -36,6 +37,8 @@
/// Configuration options used for generating HLSL.
struct Options {
+ /// The binding point to use for information passed via root constants.
+ sem::BindingPoint root_constant_binding_point;
/// Set to `true` to disable workgroup memory zero initialization
bool disable_workgroup_init = false;
};
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index c70e331..94e5de5 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -51,6 +51,7 @@
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/loop_to_for_loop.h"
#include "src/transform/manager.h"
+#include "src/transform/num_workgroups_from_uniform.h"
#include "src/transform/pad_array_elements.h"
#include "src/transform/promote_initializers_to_const_var.h"
#include "src/transform/simplify.h"
@@ -113,7 +114,9 @@
} // namespace
-SanitizedResult Sanitize(const Program* in, bool disable_workgroup_init) {
+SanitizedResult Sanitize(const Program* in,
+ sem::BindingPoint root_constant_binding_point,
+ bool disable_workgroup_init) {
transform::Manager manager;
transform::DataMap data;
@@ -128,6 +131,10 @@
manager.Add<transform::ZeroInitWorkgroupMemory>();
}
manager.Add<transform::CanonicalizeEntryPointIO>();
+ // NumWorkgroupsFromUniform must come after CanonicalizeEntryPointIO, as it
+ // assumes that num_workgroups builtins only appear as struct members and are
+ // only accessed directly via member accessors.
+ manager.Add<transform::NumWorkgroupsFromUniform>();
manager.Add<transform::InlinePointerLets>();
// Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
manager.Add<transform::Simplify>();
@@ -147,6 +154,8 @@
data.Add<transform::CanonicalizeEntryPointIO::Config>(
transform::CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
+ data.Add<transform::NumWorkgroupsFromUniform::Config>(
+ root_constant_binding_point);
SanitizedResult result;
result.program = std::move(manager.Run(in, data).program);
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 17a6674..5717d71 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -33,6 +33,7 @@
#include "src/ast/unary_op_expression.h"
#include "src/program_builder.h"
#include "src/scope_stack.h"
+#include "src/sem/binding_point.h"
#include "src/transform/decompose_memory_access.h"
#include "src/utils/hash.h"
#include "src/writer/text_generator.h"
@@ -55,9 +56,12 @@
};
/// Sanitize a program in preparation for generating HLSL.
+/// @param root_constant_binding_point the binding point to use for information
+/// that will be passed via root constants
/// @param disable_workgroup_init `true` to disable workgroup memory zero
/// @returns the sanitized program and any supplementary information
SanitizedResult Sanitize(const Program* program,
+ sem::BindingPoint root_constant_binding_point = {},
bool disable_workgroup_init = false);
/// Implementation class for HLSL generator
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 87d7b50..b790c44 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1587,6 +1587,8 @@
return "thread_position_in_grid";
case ast::Builtin::kWorkgroupId:
return "threadgroup_position_in_grid";
+ case ast::Builtin::kNumWorkgroups:
+ return "threadgroups_per_grid";
case ast::Builtin::kSampleIndex:
return "sample_id";
case ast::Builtin::kSampleMask:
diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc
index f5abc64..5ce939e 100644
--- a/src/writer/msl/generator_impl_test.cc
+++ b/src/writer/msl/generator_impl_test.cc
@@ -75,6 +75,8 @@
"thread_position_in_grid"},
MslBuiltinData{ast::Builtin::kWorkgroupId,
"threadgroup_position_in_grid"},
+ MslBuiltinData{ast::Builtin::kNumWorkgroups,
+ "threadgroups_per_grid"},
MslBuiltinData{ast::Builtin::kSampleIndex, "sample_id"},
MslBuiltinData{ast::Builtin::kSampleMask, "sample_mask"},
MslBuiltinData{ast::Builtin::kPointSize, "point_size"}));
diff --git a/test/shader_io/compute_input_builtins.wgsl b/test/shader_io/compute_input_builtins.wgsl
index a90fc66..016ec31 100644
--- a/test/shader_io/compute_input_builtins.wgsl
+++ b/test/shader_io/compute_input_builtins.wgsl
@@ -4,11 +4,12 @@
[[builtin(local_invocation_index)]] local_invocation_index : u32,
[[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>,
[[builtin(workgroup_id)]] workgroup_id : vec3<u32>,
- // TODO(crbug.com/tint/752): [[builtin(num_workgroups)]] num_workgroups : vec3<u32>;
+ [[builtin(num_workgroups)]] num_workgroups : vec3<u32>,
) {
let foo : u32 =
local_invocation_id.x +
- local_invocation_index +
- global_invocation_id.x +
- workgroup_id.x;
+ local_invocation_index +
+ global_invocation_id.x +
+ workgroup_id.x +
+ num_workgroups.x;
}
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.hlsl b/test/shader_io/compute_input_builtins.wgsl.expected.hlsl
index cb71cd9..d6ab749 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.hlsl
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.hlsl
@@ -1,3 +1,7 @@
+cbuffer cbuffer_tint_symbol_3 : register(b0, space0) {
+ uint4 tint_symbol_3[1];
+};
+
struct tint_symbol_1 {
uint3 local_invocation_id : SV_GroupThreadID;
uint local_invocation_index : SV_GroupIndex;
@@ -5,12 +9,12 @@
uint3 workgroup_id : SV_GroupID;
};
-void main_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id) {
- const uint foo = (((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x);
+void main_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id, uint3 num_workgroups) {
+ const uint foo = ((((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x) + num_workgroups.x);
}
[numthreads(1, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
- main_inner(tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id);
+ main_inner(tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id, tint_symbol_3[0].xyz);
return;
}
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.msl b/test/shader_io/compute_input_builtins.wgsl.expected.msl
index 60095f3..67ecc1b 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.msl
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.msl
@@ -1,12 +1,12 @@
#include <metal_stdlib>
using namespace metal;
-void tint_symbol_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id) {
- uint const foo = (((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x);
+void tint_symbol_inner(uint3 local_invocation_id, uint local_invocation_index, uint3 global_invocation_id, uint3 workgroup_id, uint3 num_workgroups) {
+ uint const foo = ((((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x) + num_workgroups.x);
}
-kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]]) {
- tint_symbol_inner(local_invocation_id, local_invocation_index, global_invocation_id, workgroup_id);
+kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]], uint3 num_workgroups [[threadgroups_per_grid]]) {
+ tint_symbol_inner(local_invocation_id, local_invocation_index, global_invocation_id, workgroup_id, num_workgroups);
return;
}
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.spvasm b/test/shader_io/compute_input_builtins.wgsl.expected.spvasm
index a1dddd0..6c5e5d6 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.spvasm
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.spvasm
@@ -1,26 +1,29 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
-; Bound: 31
+; Bound: 36
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
- OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1
+ OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1 %num_workgroups_1
OpExecutionMode %main LocalSize 1 1 1
OpName %local_invocation_id_1 "local_invocation_id_1"
OpName %local_invocation_index_1 "local_invocation_index_1"
OpName %global_invocation_id_1 "global_invocation_id_1"
OpName %workgroup_id_1 "workgroup_id_1"
+ OpName %num_workgroups_1 "num_workgroups_1"
OpName %main_inner "main_inner"
OpName %local_invocation_id "local_invocation_id"
OpName %local_invocation_index "local_invocation_index"
OpName %global_invocation_id "global_invocation_id"
OpName %workgroup_id "workgroup_id"
+ OpName %num_workgroups "num_workgroups"
OpName %main "main"
OpDecorate %local_invocation_id_1 BuiltIn LocalInvocationId
OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
OpDecorate %global_invocation_id_1 BuiltIn GlobalInvocationId
OpDecorate %workgroup_id_1 BuiltIn WorkgroupId
+ OpDecorate %num_workgroups_1 BuiltIn NumWorkgroups
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
@@ -29,29 +32,34 @@
%local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
%global_invocation_id_1 = OpVariable %_ptr_Input_v3uint Input
%workgroup_id_1 = OpVariable %_ptr_Input_v3uint Input
+%num_workgroups_1 = OpVariable %_ptr_Input_v3uint Input
%void = OpTypeVoid
- %9 = OpTypeFunction %void %v3uint %uint %v3uint %v3uint
- %23 = OpTypeFunction %void
- %main_inner = OpFunction %void None %9
+ %10 = OpTypeFunction %void %v3uint %uint %v3uint %v3uint %v3uint
+ %27 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %10
%local_invocation_id = OpFunctionParameter %v3uint
%local_invocation_index = OpFunctionParameter %uint
%global_invocation_id = OpFunctionParameter %v3uint
%workgroup_id = OpFunctionParameter %v3uint
- %16 = OpLabel
- %17 = OpCompositeExtract %uint %local_invocation_id 0
- %18 = OpIAdd %uint %17 %local_invocation_index
- %19 = OpCompositeExtract %uint %global_invocation_id 0
- %20 = OpIAdd %uint %18 %19
- %21 = OpCompositeExtract %uint %workgroup_id 0
+%num_workgroups = OpFunctionParameter %v3uint
+ %18 = OpLabel
+ %19 = OpCompositeExtract %uint %local_invocation_id 0
+ %20 = OpIAdd %uint %19 %local_invocation_index
+ %21 = OpCompositeExtract %uint %global_invocation_id 0
%22 = OpIAdd %uint %20 %21
+ %23 = OpCompositeExtract %uint %workgroup_id 0
+ %24 = OpIAdd %uint %22 %23
+ %25 = OpCompositeExtract %uint %num_workgroups 0
+ %26 = OpIAdd %uint %24 %25
OpReturn
OpFunctionEnd
- %main = OpFunction %void None %23
- %25 = OpLabel
- %27 = OpLoad %v3uint %local_invocation_id_1
- %28 = OpLoad %uint %local_invocation_index_1
- %29 = OpLoad %v3uint %global_invocation_id_1
- %30 = OpLoad %v3uint %workgroup_id_1
- %26 = OpFunctionCall %void %main_inner %27 %28 %29 %30
+ %main = OpFunction %void None %27
+ %29 = OpLabel
+ %31 = OpLoad %v3uint %local_invocation_id_1
+ %32 = OpLoad %uint %local_invocation_index_1
+ %33 = OpLoad %v3uint %global_invocation_id_1
+ %34 = OpLoad %v3uint %workgroup_id_1
+ %35 = OpLoad %v3uint %num_workgroups_1
+ %30 = OpFunctionCall %void %main_inner %31 %32 %33 %34 %35
OpReturn
OpFunctionEnd
diff --git a/test/shader_io/compute_input_builtins.wgsl.expected.wgsl b/test/shader_io/compute_input_builtins.wgsl.expected.wgsl
index 518b32e..42c8adb 100644
--- a/test/shader_io/compute_input_builtins.wgsl.expected.wgsl
+++ b/test/shader_io/compute_input_builtins.wgsl.expected.wgsl
@@ -1,4 +1,4 @@
[[stage(compute), workgroup_size(1)]]
-fn main([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32, [[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>, [[builtin(workgroup_id)]] workgroup_id : vec3<u32>) {
- let foo : u32 = (((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x);
+fn main([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32, [[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>, [[builtin(workgroup_id)]] workgroup_id : vec3<u32>, [[builtin(num_workgroups)]] num_workgroups : vec3<u32>) {
+ let foo : u32 = ((((local_invocation_id.x + local_invocation_index) + global_invocation_id.x) + workgroup_id.x) + num_workgroups.x);
}
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl b/test/shader_io/compute_input_builtins_struct.wgsl
index 61b8f64..c9691c9 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl
@@ -3,14 +3,15 @@
[[builtin(local_invocation_index)]] local_invocation_index : u32;
[[builtin(global_invocation_id)]] global_invocation_id : vec3<u32>;
[[builtin(workgroup_id)]] workgroup_id : vec3<u32>;
- // TODO(crbug.com/tint/752): [[builtin(num_workgroups)]] num_workgroups : vec3<u32>;
+ [[builtin(num_workgroups)]] num_workgroups : vec3<u32>;
};
[[stage(compute), workgroup_size(1)]]
fn main(inputs : ComputeInputs) {
let foo : u32 =
inputs.local_invocation_id.x +
- inputs.local_invocation_index +
- inputs.global_invocation_id.x +
- inputs.workgroup_id.x;
+ inputs.local_invocation_index +
+ inputs.global_invocation_id.x +
+ inputs.workgroup_id.x +
+ inputs.num_workgroups.x;
}
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl b/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl
index 8e00493..783117e 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.hlsl
@@ -1,8 +1,13 @@
+cbuffer cbuffer_tint_symbol_3 : register(b0, space0) {
+ uint4 tint_symbol_3[1];
+};
+
struct ComputeInputs {
uint3 local_invocation_id;
uint local_invocation_index;
uint3 global_invocation_id;
uint3 workgroup_id;
+ uint3 num_workgroups;
};
struct tint_symbol_1 {
uint3 local_invocation_id : SV_GroupThreadID;
@@ -12,12 +17,12 @@
};
void main_inner(ComputeInputs inputs) {
- const uint foo = (((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x);
+ const uint foo = ((((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x) + inputs.num_workgroups.x);
}
[numthreads(1, 1, 1)]
void main(tint_symbol_1 tint_symbol) {
- const ComputeInputs tint_symbol_2 = {tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id};
- main_inner(tint_symbol_2);
+ const ComputeInputs tint_symbol_5 = {tint_symbol.local_invocation_id, tint_symbol.local_invocation_index, tint_symbol.global_invocation_id, tint_symbol.workgroup_id, tint_symbol_3[0].xyz};
+ main_inner(tint_symbol_5);
return;
}
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl b/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl
index 56db89c..34db10e 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.msl
@@ -6,14 +6,15 @@
uint local_invocation_index;
uint3 global_invocation_id;
uint3 workgroup_id;
+ uint3 num_workgroups;
};
void tint_symbol_inner(ComputeInputs inputs) {
- uint const foo = (((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x);
+ uint const foo = ((((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x) + inputs.num_workgroups.x);
}
-kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]]) {
- ComputeInputs const tint_symbol_1 = {.local_invocation_id=local_invocation_id, .local_invocation_index=local_invocation_index, .global_invocation_id=global_invocation_id, .workgroup_id=workgroup_id};
+kernel void tint_symbol(uint3 local_invocation_id [[thread_position_in_threadgroup]], uint local_invocation_index [[thread_index_in_threadgroup]], uint3 global_invocation_id [[thread_position_in_grid]], uint3 workgroup_id [[threadgroup_position_in_grid]], uint3 num_workgroups [[threadgroups_per_grid]]) {
+ ComputeInputs const tint_symbol_1 = {.local_invocation_id=local_invocation_id, .local_invocation_index=local_invocation_index, .global_invocation_id=global_invocation_id, .workgroup_id=workgroup_id, .num_workgroups=num_workgroups};
tint_symbol_inner(tint_symbol_1);
return;
}
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm b/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm
index d270e75..9e26e28 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.spvasm
@@ -1,21 +1,23 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
-; Bound: 34
+; Bound: 39
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
- OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1
+ OpEntryPoint GLCompute %main "main" %local_invocation_id_1 %local_invocation_index_1 %global_invocation_id_1 %workgroup_id_1 %num_workgroups_1
OpExecutionMode %main LocalSize 1 1 1
OpName %local_invocation_id_1 "local_invocation_id_1"
OpName %local_invocation_index_1 "local_invocation_index_1"
OpName %global_invocation_id_1 "global_invocation_id_1"
OpName %workgroup_id_1 "workgroup_id_1"
+ OpName %num_workgroups_1 "num_workgroups_1"
OpName %ComputeInputs "ComputeInputs"
OpMemberName %ComputeInputs 0 "local_invocation_id"
OpMemberName %ComputeInputs 1 "local_invocation_index"
OpMemberName %ComputeInputs 2 "global_invocation_id"
OpMemberName %ComputeInputs 3 "workgroup_id"
+ OpMemberName %ComputeInputs 4 "num_workgroups"
OpName %main_inner "main_inner"
OpName %inputs "inputs"
OpName %main "main"
@@ -23,10 +25,12 @@
OpDecorate %local_invocation_index_1 BuiltIn LocalInvocationIndex
OpDecorate %global_invocation_id_1 BuiltIn GlobalInvocationId
OpDecorate %workgroup_id_1 BuiltIn WorkgroupId
+ OpDecorate %num_workgroups_1 BuiltIn NumWorkgroups
OpMemberDecorate %ComputeInputs 0 Offset 0
OpMemberDecorate %ComputeInputs 1 Offset 12
OpMemberDecorate %ComputeInputs 2 Offset 16
OpMemberDecorate %ComputeInputs 3 Offset 32
+ OpMemberDecorate %ComputeInputs 4 Offset 48
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
@@ -35,32 +39,37 @@
%local_invocation_index_1 = OpVariable %_ptr_Input_uint Input
%global_invocation_id_1 = OpVariable %_ptr_Input_v3uint Input
%workgroup_id_1 = OpVariable %_ptr_Input_v3uint Input
+%num_workgroups_1 = OpVariable %_ptr_Input_v3uint Input
%void = OpTypeVoid
-%ComputeInputs = OpTypeStruct %v3uint %uint %v3uint %v3uint
- %9 = OpTypeFunction %void %ComputeInputs
- %25 = OpTypeFunction %void
- %main_inner = OpFunction %void None %9
+%ComputeInputs = OpTypeStruct %v3uint %uint %v3uint %v3uint %v3uint
+ %10 = OpTypeFunction %void %ComputeInputs
+ %29 = OpTypeFunction %void
+ %main_inner = OpFunction %void None %10
%inputs = OpFunctionParameter %ComputeInputs
- %14 = OpLabel
- %15 = OpCompositeExtract %v3uint %inputs 0
- %16 = OpCompositeExtract %uint %15 0
- %17 = OpCompositeExtract %uint %inputs 1
- %18 = OpIAdd %uint %16 %17
- %19 = OpCompositeExtract %v3uint %inputs 2
- %20 = OpCompositeExtract %uint %19 0
- %21 = OpIAdd %uint %18 %20
- %22 = OpCompositeExtract %v3uint %inputs 3
- %23 = OpCompositeExtract %uint %22 0
- %24 = OpIAdd %uint %21 %23
+ %15 = OpLabel
+ %16 = OpCompositeExtract %v3uint %inputs 0
+ %17 = OpCompositeExtract %uint %16 0
+ %18 = OpCompositeExtract %uint %inputs 1
+ %19 = OpIAdd %uint %17 %18
+ %20 = OpCompositeExtract %v3uint %inputs 2
+ %21 = OpCompositeExtract %uint %20 0
+ %22 = OpIAdd %uint %19 %21
+ %23 = OpCompositeExtract %v3uint %inputs 3
+ %24 = OpCompositeExtract %uint %23 0
+ %25 = OpIAdd %uint %22 %24
+ %26 = OpCompositeExtract %v3uint %inputs 4
+ %27 = OpCompositeExtract %uint %26 0
+ %28 = OpIAdd %uint %25 %27
OpReturn
OpFunctionEnd
- %main = OpFunction %void None %25
- %27 = OpLabel
- %29 = OpLoad %v3uint %local_invocation_id_1
- %30 = OpLoad %uint %local_invocation_index_1
- %31 = OpLoad %v3uint %global_invocation_id_1
- %32 = OpLoad %v3uint %workgroup_id_1
- %33 = OpCompositeConstruct %ComputeInputs %29 %30 %31 %32
- %28 = OpFunctionCall %void %main_inner %33
+ %main = OpFunction %void None %29
+ %31 = OpLabel
+ %33 = OpLoad %v3uint %local_invocation_id_1
+ %34 = OpLoad %uint %local_invocation_index_1
+ %35 = OpLoad %v3uint %global_invocation_id_1
+ %36 = OpLoad %v3uint %workgroup_id_1
+ %37 = OpLoad %v3uint %num_workgroups_1
+ %38 = OpCompositeConstruct %ComputeInputs %33 %34 %35 %36 %37
+ %32 = OpFunctionCall %void %main_inner %38
OpReturn
OpFunctionEnd
diff --git a/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl b/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl
index 99b0c4f..23d354a 100644
--- a/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl
+++ b/test/shader_io/compute_input_builtins_struct.wgsl.expected.wgsl
@@ -7,9 +7,11 @@
global_invocation_id : vec3<u32>;
[[builtin(workgroup_id)]]
workgroup_id : vec3<u32>;
+ [[builtin(num_workgroups)]]
+ num_workgroups : vec3<u32>;
};
[[stage(compute), workgroup_size(1)]]
fn main(inputs : ComputeInputs) {
- let foo : u32 = (((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x);
+ let foo : u32 = ((((inputs.local_invocation_id.x + inputs.local_invocation_index) + inputs.global_invocation_id.x) + inputs.workgroup_id.x) + inputs.num_workgroups.x);
}