D3D12: Support [[num_workgroups]] for Dispatch
This patch implements [[num_workgroups]] on the API side for
Dispatch() calls by setting num_workgroups.xyz as root constants.
This patch also adds a temporary validation that on D3D12 backend
using a compute pipeline with [[num_workgroups]] in a
DispatchIndirect call is not supported.
BUG=dawn:839
TEST=dawn_end2end_tests
Change-Id: Iaee2ffd162e9420e4e80944fbb222f10a4600c6a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/66580
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index bd3989b..cfb45a1 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -677,6 +677,8 @@
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
+
+ metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
}
if (metadata->stage == SingleShaderStage::Vertex) {
diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h
index 8f8081c..070d001 100644
--- a/src/dawn_native/ShaderModule.h
+++ b/src/dawn_native/ShaderModule.h
@@ -204,6 +204,8 @@
// Store overridableConstants from tint program
std::unordered_map<std::string, OverridableConstant> overridableConstants;
+
+ bool usesNumWorkgroups = false;
};
class ShaderModuleBase : public ApiObjectBase, public CachedObject {
diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
index 4bdc3b0..dc32d10 100644
--- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp
+++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp
@@ -254,6 +254,18 @@
return {};
}
+ void RecordNumWorkgroupsForDispatch(ID3D12GraphicsCommandList* commandList,
+ ComputePipeline* pipeline,
+ DispatchCmd* dispatch) {
+ if (!pipeline->UsesNumWorkgroups()) {
+ return;
+ }
+
+ PipelineLayout* layout = ToBackend(pipeline->GetLayout());
+ commandList->SetComputeRoot32BitConstants(layout->GetNumWorkgroupsParameterIndex(), 3,
+ dispatch, 0);
+ }
+
// Records the necessary barriers for a synchronization scope using the resource usage
// data pre-computed in the frontend. Also performs lazy initialization if required.
// Returns whether any UAV are used in the synchronization scope.
@@ -1030,6 +1042,7 @@
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
Command type;
+ ComputePipeline* lastPipeline = nullptr;
while (mCommands.NextCommandId(&type)) {
switch (type) {
case Command::Dispatch: {
@@ -1045,6 +1058,7 @@
resourceUsages.dispatchUsages[currentDispatch]);
DAWN_TRY(bindingTracker->Apply(commandContext));
+ RecordNumWorkgroupsForDispatch(commandList, lastPipeline, dispatch);
commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
currentDispatch++;
break;
@@ -1052,6 +1066,14 @@
case Command::DispatchIndirect: {
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
+
+ // TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
+ if (lastPipeline->UsesNumWorkgroups()) {
+ return DAWN_VALIDATION_ERROR(
+ "Using a compute pipeline with [[num_workgroups]] in a "
+ "DispatchIndirect call is not implemented");
+ }
+
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
TransitionAndClearForSyncScope(commandContext,
@@ -1078,6 +1100,7 @@
commandList->SetPipelineState(pipeline->GetPipelineState());
bindingTracker->OnSetPipeline(pipeline);
+ lastPipeline = pipeline;
break;
}
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
index 29ae08a..54ddc7e 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
@@ -84,4 +84,8 @@
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
}
+ bool ComputePipeline::UsesNumWorkgroups() const {
+ return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
+ }
+
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.h b/src/dawn_native/d3d12/ComputePipelineD3D12.h
index 7c7a02d..b652026 100644
--- a/src/dawn_native/d3d12/ComputePipelineD3D12.h
+++ b/src/dawn_native/d3d12/ComputePipelineD3D12.h
@@ -40,6 +40,8 @@
// Dawn API
void SetLabelImpl() override;
+ bool UsesNumWorkgroups() const;
+
private:
~ComputePipeline() override;
using ComputePipelineBase::ComputePipelineBase;
diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
index 372b61b..1a512fa 100644
--- a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
+++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp
@@ -174,6 +174,21 @@
// would need to be updated often
rootParameters.emplace_back(indexOffsetConstants);
+ // Always allocate 3 constants for num_workgroups_x, num_workgroups_y and num_workgroups_z
+ // for Dispatch calls
+ // NOTE: We should consider delaying root signature creation until we know how many values
+ // we need
+ D3D12_ROOT_PARAMETER numWorkgroupsConstants{};
+ numWorkgroupsConstants.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
+ numWorkgroupsConstants.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
+ numWorkgroupsConstants.Constants.Num32BitValues = 3;
+ numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
+ numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
+ mNumWorkgroupsParamterIndex = rootParameters.size();
+ // NOTE: We should consider moving this entry to earlier in the root signature since
+ // dispatch sizes would need to be updated often
+ rootParameters.emplace_back(numWorkgroupsConstants);
+
D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor;
rootSignatureDescriptor.NumParameters = rootParameters.size();
rootSignatureDescriptor.pParameters = rootParameters.data();
@@ -230,7 +245,7 @@
}
uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const {
- return kReservedRegisterSpace;
+ return kFirstIndexOffsetRegisterSpace;
}
uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const {
@@ -240,4 +255,16 @@
uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const {
return mFirstIndexOffsetParameterIndex;
}
+
+ uint32_t PipelineLayout::GetNumWorkgroupsRegisterSpace() const {
+ return kNumWorkgroupsRegisterSpace;
+ }
+
+ uint32_t PipelineLayout::GetNumWorkgroupsShaderRegister() const {
+ return kNumWorkgroupsBaseRegister;
+ }
+
+ uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
+ return mNumWorkgroupsParamterIndex;
+ }
}} // namespace dawn_native::d3d12
diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.h b/src/dawn_native/d3d12/PipelineLayoutD3D12.h
index b1efc0d..cf52f06 100644
--- a/src/dawn_native/d3d12/PipelineLayoutD3D12.h
+++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.h
@@ -26,6 +26,9 @@
// We reserve a register space that a user cannot use.
static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1;
static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0;
+ static constexpr uint32_t kFirstIndexOffsetRegisterSpace = kReservedRegisterSpace;
+ static constexpr uint32_t kNumWorkgroupsRegisterSpace = kReservedRegisterSpace + 1;
+ static constexpr uint32_t kNumWorkgroupsBaseRegister = 0;
class Device;
@@ -46,6 +49,10 @@
uint32_t GetFirstIndexOffsetShaderRegister() const;
uint32_t GetFirstIndexOffsetParameterIndex() const;
+ uint32_t GetNumWorkgroupsRegisterSpace() const;
+ uint32_t GetNumWorkgroupsShaderRegister() const;
+ uint32_t GetNumWorkgroupsParameterIndex() const;
+
ID3D12RootSignature* GetRootSignature() const;
private:
@@ -59,6 +66,7 @@
kMaxBindGroups>
mDynamicRootParameterIndices;
uint32_t mFirstIndexOffsetParameterIndex;
+ uint32_t mNumWorkgroupsParamterIndex;
ComPtr<ID3D12RootSignature> mRootSignature;
};
diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
index 2dafc96..89b5825 100644
--- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
+++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp
@@ -106,6 +106,9 @@
tint::transform::BindingRemapper::BindingPoints bindingPoints;
tint::transform::BindingRemapper::AccessControls accessControls;
bool isRobustnessEnabled;
+ bool usesNumWorkgroups;
+ uint32_t numWorkgroupsRegisterSpace;
+ uint32_t numWorkgroupsShaderRegister;
// FXC/DXC common inputs
bool disableWorkgroupInit;
@@ -125,7 +128,7 @@
uint32_t compileFlags,
const Device* device,
const tint::Program* program,
- const BindingInfoArray& moduleBindingInfo) {
+ const EntryPointMetadata& entryPoint) {
Compiler compiler;
uint64_t dxcVersion = 0;
if (device->IsToggleEnabled(Toggle::UseDXC)) {
@@ -145,6 +148,7 @@
// Tint AST to make the "bindings" decoration match the offset chosen by
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
// assigned to each interface variable.
+ const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
const auto& groupBindingInfo = moduleBindingInfo[group];
@@ -189,6 +193,9 @@
request.isRobustnessEnabled = device->IsRobustnessEnabled();
request.disableWorkgroupInit =
device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
+ request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
+ request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
+ request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
request.deviceInfo = &device->GetDeviceInfo();
@@ -234,6 +241,10 @@
stream << " accessControls=";
Serialize(stream, accessControls);
+ stream << " useNumWorkgroups=" << usesNumWorkgroups;
+ stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
+ stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
+
stream << " shaderModel=" << deviceInfo->shaderModel;
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
stream << " isRobustnessEnabled=" << isRobustnessEnabled;
@@ -423,6 +434,10 @@
tint::writer::hlsl::Options options;
options.disable_workgroup_init = request.disableWorkgroupInit;
+ if (request.usesNumWorkgroups) {
+ options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
+ options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
+ }
auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
if (!result.success) {
errorStream << "Generator: " << result.error << std::endl;
@@ -547,9 +562,9 @@
}
ShaderCompilationRequest request;
- DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
- entryPointName, stage, layout, compileFlags, device, program,
- GetEntryPoint(entryPointName).bindings));
+ DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
+ compileFlags, device, program,
+ GetEntryPoint(entryPointName)));
PersistentCacheKey shaderCacheKey;
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp
index 7ec4076..1a8b163 100644
--- a/src/tests/end2end/ComputeDispatchTests.cpp
+++ b/src/tests/end2end/ComputeDispatchTests.cpp
@@ -26,9 +26,30 @@
DawnTest::SetUp();
// Write workgroup number into the output buffer if we saw the biggest dispatch
- // This is a workaround since D3D12 doesn't have gl_NumWorkGroups
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
- wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
+ wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
+ [[block]] struct OutputBuf {
+ workGroups : vec3<u32>;
+ };
+
+ [[group(0), binding(0)]] var<storage, read_write> output : OutputBuf;
+
+ [[stage(compute), workgroup_size(1, 1, 1)]]
+ fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>,
+ [[builtin(num_workgroups)]] dispatch : vec3<u32>) {
+ if (dispatch.x == 0u || dispatch.y == 0u || dispatch.z == 0u) {
+ output.workGroups = vec3<u32>(0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu);
+ return;
+ }
+
+ if (all(GlobalInvocationID == dispatch - vec3<u32>(1u, 1u, 1u))) {
+ output.workGroups = dispatch;
+ }
+ })");
+
+ // TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
+ // [[num_workgroups]] for indirect dispatch.
+ wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
[[block]] struct InputBuf {
expectedDispatch : vec3<u32>;
};
@@ -54,9 +75,12 @@
})");
wgpu::ComputePipelineDescriptor csDesc;
- csDesc.compute.module = module;
+ csDesc.compute.module = moduleForDispatch;
csDesc.compute.entryPoint = "main";
- pipeline = device.CreateComputePipeline(&csDesc);
+ pipelineForDispatch = device.CreateComputePipeline(&csDesc);
+
+ csDesc.compute.module = moduleForDispatchIndirect;
+ pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
}
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
@@ -66,23 +90,18 @@
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
kSentinelData);
- std::initializer_list<uint32_t> expectedBufferData{x, y, z};
- wgpu::Buffer expectedBuffer = utils::CreateBufferFromData<uint32_t>(
- device, wgpu::BufferUsage::Uniform, expectedBufferData);
-
// Set up bind group and issue dispatch
wgpu::BindGroup bindGroup =
- utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+ utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
{
- {0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
- {1, dst, 0, 3 * sizeof(uint32_t)},
+ {0, dst, 0, 3 * sizeof(uint32_t)},
});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
- pass.SetPipeline(pipeline);
+ pass.SetPipeline(pipelineForDispatch);
pass.SetBindGroup(0, bindGroup);
pass.Dispatch(x, y, z);
pass.EndPass();
@@ -93,7 +112,7 @@
queue.Submit(1, &commands);
std::vector<uint32_t> expected =
- x == 0 || y == 0 || z == 0 ? kSentinelData : expectedBufferData;
+ x == 0 || y == 0 || z == 0 ? kSentinelData : std::initializer_list<uint32_t>{x, y, z};
// Verify the dispatch got called if all group counts are not zero
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
@@ -118,7 +137,7 @@
// Set up bind group and issue dispatch
wgpu::BindGroup bindGroup =
- utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
+ utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
{
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
{1, dst, 0, 3 * sizeof(uint32_t)},
@@ -128,7 +147,7 @@
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
- pass.SetPipeline(pipeline);
+ pass.SetPipeline(pipelineForDispatchIndirect);
pass.SetBindGroup(0, bindGroup);
pass.DispatchIndirect(indirectBuffer, indirectOffset);
pass.EndPass();
@@ -153,7 +172,8 @@
}
private:
- wgpu::ComputePipeline pipeline;
+ wgpu::ComputePipeline pipelineForDispatch;
+ wgpu::ComputePipeline pipelineForDispatchIndirect;
};
// Test basic direct