| // Copyright 2018 The Dawn Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "dawn_native/ComputePassEncoder.h" |
| |
| #include "dawn_native/BindGroup.h" |
| #include "dawn_native/BindGroupLayout.h" |
| #include "dawn_native/Buffer.h" |
| #include "dawn_native/CommandEncoder.h" |
| #include "dawn_native/CommandValidation.h" |
| #include "dawn_native/Commands.h" |
| #include "dawn_native/ComputePipeline.h" |
| #include "dawn_native/Device.h" |
| #include "dawn_native/InternalPipelineStore.h" |
| #include "dawn_native/ObjectType_autogen.h" |
| #include "dawn_native/PassResourceUsageTracker.h" |
| #include "dawn_native/QuerySet.h" |
| #include "dawn_native/utils/WGPUHelpers.h" |
| |
| namespace dawn_native { |
| |
| namespace { |
| |
| ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline( |
| DeviceBase* device) { |
| InternalPipelineStore* store = device->GetInternalPipelineStore(); |
| |
| if (store->dispatchIndirectValidationPipeline != nullptr) { |
| return store->dispatchIndirectValidationPipeline.Get(); |
| } |
| |
| // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this |
| // shader in various failure modes. |
| Ref<ShaderModuleBase> shaderModule; |
| DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"( |
| [[block]] struct UniformParams { |
| maxComputeWorkgroupsPerDimension: u32; |
| clientOffsetInU32: u32; |
| }; |
| |
| [[block]] struct IndirectParams { |
| data: array<u32>; |
| }; |
| |
| [[block]] struct ValidatedParams { |
| data: array<u32, 3>; |
| }; |
| |
| [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams; |
| [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams; |
| [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams; |
| |
| [[stage(compute), workgroup_size(1, 1, 1)]] |
| fn main() { |
| for (var i = 0u; i < 3u; i = i + 1u) { |
| var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i]; |
| if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) { |
| numWorkgroups = 0u; |
| } |
| validatedParams.data[i] = numWorkgroups; |
| } |
| } |
| )")); |
| |
| Ref<BindGroupLayoutBase> bindGroupLayout; |
| DAWN_TRY_ASSIGN( |
| bindGroupLayout, |
| utils::MakeBindGroupLayout( |
| device, |
| { |
| {0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform}, |
| {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding}, |
| {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage}, |
| }, |
| /* allowInternalBinding */ true)); |
| |
| Ref<PipelineLayoutBase> pipelineLayout; |
| DAWN_TRY_ASSIGN(pipelineLayout, |
| utils::MakeBasicPipelineLayout(device, bindGroupLayout)); |
| |
| ComputePipelineDescriptor computePipelineDescriptor = {}; |
| computePipelineDescriptor.layout = pipelineLayout.Get(); |
| computePipelineDescriptor.compute.module = shaderModule.Get(); |
| computePipelineDescriptor.compute.entryPoint = "main"; |
| |
| DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline, |
| device->CreateComputePipeline(&computePipelineDescriptor)); |
| |
| return store->dispatchIndirectValidationPipeline.Get(); |
| } |
| |
| } // namespace |
| |
| ComputePassEncoder::ComputePassEncoder(DeviceBase* device, |
| CommandEncoder* commandEncoder, |
| EncodingContext* encodingContext) |
| : ProgrammablePassEncoder(device, encodingContext), mCommandEncoder(commandEncoder) { |
| } |
| |
| ComputePassEncoder::ComputePassEncoder(DeviceBase* device, |
| CommandEncoder* commandEncoder, |
| EncodingContext* encodingContext, |
| ErrorTag errorTag) |
| : ProgrammablePassEncoder(device, encodingContext, errorTag), |
| mCommandEncoder(commandEncoder) { |
| } |
| |
| ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device, |
| CommandEncoder* commandEncoder, |
| EncodingContext* encodingContext) { |
| return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError); |
| } |
| |
| ObjectType ComputePassEncoder::GetType() const { |
| return ObjectType::ComputePassEncoder; |
| } |
| |
| void ComputePassEncoder::APIEndPass() { |
| if (mEncodingContext->TryEncode( |
| this, |
| [&](CommandAllocator* allocator) -> MaybeError { |
| if (IsValidationEnabled()) { |
| DAWN_TRY(ValidateProgrammableEncoderEnd()); |
| } |
| |
| allocator->Allocate<EndComputePassCmd>(Command::EndComputePass); |
| |
| return {}; |
| }, |
| "encoding %s.EndPass().", this)) { |
| mEncodingContext->ExitComputePass(this, mUsageTracker.AcquireResourceUsage()); |
| } |
| } |
| |
| void ComputePassEncoder::APIDispatch(uint32_t x, uint32_t y, uint32_t z) { |
| mEncodingContext->TryEncode( |
| this, |
| [&](CommandAllocator* allocator) -> MaybeError { |
| if (IsValidationEnabled()) { |
| DAWN_TRY(mCommandBufferState.ValidateCanDispatch()); |
| |
| uint32_t workgroupsPerDimension = |
| GetDevice()->GetLimits().v1.maxComputeWorkgroupsPerDimension; |
| |
| DAWN_INVALID_IF( |
| x > workgroupsPerDimension, |
| "Dispatch size X (%u) exceeds max compute workgroups per dimension (%u).", |
| x, workgroupsPerDimension); |
| |
| DAWN_INVALID_IF( |
| y > workgroupsPerDimension, |
| "Dispatch size Y (%u) exceeds max compute workgroups per dimension (%u).", |
| y, workgroupsPerDimension); |
| |
| DAWN_INVALID_IF( |
| z > workgroupsPerDimension, |
| "Dispatch size Z (%u) exceeds max compute workgroups per dimension (%u).", |
| z, workgroupsPerDimension); |
| } |
| |
| // Record the synchronization scope for Dispatch, which is just the current |
| // bindgroups. |
| AddDispatchSyncScope(); |
| |
| DispatchCmd* dispatch = allocator->Allocate<DispatchCmd>(Command::Dispatch); |
| dispatch->x = x; |
| dispatch->y = y; |
| dispatch->z = z; |
| |
| return {}; |
| }, |
| "encoding %s.Dispatch(%u, %u, %u).", this, x, y, z); |
| } |
| |
| ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> |
| ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer, |
| uint64_t indirectOffset) { |
| DeviceBase* device = GetDevice(); |
| auto* const store = device->GetInternalPipelineStore(); |
| |
| Ref<ComputePipelineBase> validationPipeline; |
| DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device)); |
| |
| Ref<BindGroupLayoutBase> layout; |
| DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0)); |
| |
| uint32_t storageBufferOffsetAlignment = |
| device->GetLimits().v1.minStorageBufferOffsetAlignment; |
| |
| // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|. |
| const uint32_t clientOffsetFromAlignedBoundary = |
| indirectOffset % storageBufferOffsetAlignment; |
| const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary; |
| const uint64_t clientIndirectBindingOffset = clientOffsetAlignedDown; |
| |
| // Let the size of the binding be the additional offset, plus the size. |
| const uint64_t clientIndirectBindingSize = |
| kDispatchIndirectSize + clientOffsetFromAlignedBoundary; |
| |
| struct UniformParams { |
| uint32_t maxComputeWorkgroupsPerDimension; |
| uint32_t clientOffsetInU32; |
| }; |
| |
| // Create a uniform buffer to hold parameters for the shader. |
| Ref<BufferBase> uniformBuffer; |
| { |
| UniformParams params; |
| params.maxComputeWorkgroupsPerDimension = |
| device->GetLimits().v1.maxComputeWorkgroupsPerDimension; |
| params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t); |
| |
| DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData( |
| device, wgpu::BufferUsage::Uniform, {params})); |
| } |
| |
| // Reserve space in the scratch buffer to hold the validated indirect params. |
| ScratchBuffer& scratchBuffer = store->scratchIndirectStorage; |
| DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize)); |
| Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer(); |
| |
| Ref<BindGroupBase> validationBindGroup; |
| DAWN_TRY_ASSIGN( |
| validationBindGroup, |
| utils::MakeBindGroup( |
| device, layout, |
| { |
| {0, uniformBuffer}, |
| {1, indirectBuffer, clientIndirectBindingOffset, clientIndirectBindingSize}, |
| {2, validatedIndirectBuffer, 0, kDispatchIndirectSize}, |
| })); |
| |
| // Issue commands to validate the indirect buffer. |
| APISetPipeline(validationPipeline.Get()); |
| APISetBindGroup(0, validationBindGroup.Get()); |
| APIDispatch(1); |
| |
| // Return the new indirect buffer and indirect buffer offset. |
| return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0)); |
| } |
| |
| void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer, |
| uint64_t indirectOffset) { |
| mEncodingContext->TryEncode( |
| this, |
| [&](CommandAllocator* allocator) -> MaybeError { |
| if (IsValidationEnabled()) { |
| DAWN_TRY(GetDevice()->ValidateObject(indirectBuffer)); |
| DAWN_TRY(ValidateCanUseAs(indirectBuffer, wgpu::BufferUsage::Indirect)); |
| DAWN_TRY(mCommandBufferState.ValidateCanDispatch()); |
| |
| DAWN_INVALID_IF(indirectOffset % 4 != 0, |
| "Indirect offset (%u) is not a multiple of 4.", indirectOffset); |
| |
| DAWN_INVALID_IF( |
| indirectOffset >= indirectBuffer->GetSize() || |
| indirectOffset + kDispatchIndirectSize > indirectBuffer->GetSize(), |
| "Indirect offset (%u) and dispatch size (%u) exceeds the indirect buffer " |
| "size (%u).", |
| indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize()); |
| } |
| |
| SyncScopeUsageTracker scope; |
| scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect); |
| mUsageTracker.AddReferencedBuffer(indirectBuffer); |
| // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer| |
| // is needed for correct usage validation even though it will only be bound for |
| // storage. This will unecessarily transition the |indirectBuffer| in |
| // the backend. |
| |
| Ref<BufferBase> indirectBufferRef = indirectBuffer; |
| if (IsValidationEnabled()) { |
| // Save the previous command buffer state so it can be restored after the |
| // validation inserts additional commands. |
| CommandBufferStateTracker previousState = mCommandBufferState; |
| |
| // Validate each indirect dispatch with a single dispatch to copy the indirect |
| // buffer params into a scratch buffer if they're valid, and otherwise zero them |
| // out. We could consider moving the validation earlier in the pass after the |
| // last point the indirect buffer was used with writable usage, as well as batch |
| // validation for multiple dispatches into one, but inserting commands at |
| // arbitrary points in the past is not possible right now. |
| DAWN_TRY_ASSIGN( |
| std::tie(indirectBufferRef, indirectOffset), |
| ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset)); |
| |
| // Restore the state. |
| RestoreCommandBufferState(std::move(previousState)); |
| |
| // |indirectBufferRef| was replaced with a scratch buffer. Add it to the |
| // synchronization scope. |
| ASSERT(indirectBufferRef.Get() != indirectBuffer); |
| scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect); |
| mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get()); |
| } |
| |
| AddDispatchSyncScope(std::move(scope)); |
| |
| DispatchIndirectCmd* dispatch = |
| allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect); |
| dispatch->indirectBuffer = std::move(indirectBufferRef); |
| dispatch->indirectOffset = indirectOffset; |
| return {}; |
| }, |
| "encoding %s.DispatchIndirect(%s, %u).", this, indirectBuffer, indirectOffset); |
| } |
| |
| void ComputePassEncoder::APISetPipeline(ComputePipelineBase* pipeline) { |
| mEncodingContext->TryEncode( |
| this, |
| [&](CommandAllocator* allocator) -> MaybeError { |
| if (IsValidationEnabled()) { |
| DAWN_TRY(GetDevice()->ValidateObject(pipeline)); |
| } |
| |
| mCommandBufferState.SetComputePipeline(pipeline); |
| |
| SetComputePipelineCmd* cmd = |
| allocator->Allocate<SetComputePipelineCmd>(Command::SetComputePipeline); |
| cmd->pipeline = pipeline; |
| |
| return {}; |
| }, |
| "encoding %s.SetPipeline(%s).", this, pipeline); |
| } |
| |
| void ComputePassEncoder::APISetBindGroup(uint32_t groupIndexIn, |
| BindGroupBase* group, |
| uint32_t dynamicOffsetCount, |
| const uint32_t* dynamicOffsets) { |
| mEncodingContext->TryEncode( |
| this, |
| [&](CommandAllocator* allocator) -> MaybeError { |
| BindGroupIndex groupIndex(groupIndexIn); |
| |
| if (IsValidationEnabled()) { |
| DAWN_TRY(ValidateSetBindGroup(groupIndex, group, dynamicOffsetCount, |
| dynamicOffsets)); |
| } |
| |
| mUsageTracker.AddResourcesReferencedByBindGroup(group); |
| RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount, |
| dynamicOffsets); |
| mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount, |
| dynamicOffsets); |
| |
| return {}; |
| }, |
| "encoding %s.SetBindGroup(%u, %s, %u, ...).", this, groupIndexIn, group, |
| dynamicOffsetCount); |
| } |
| |
| void ComputePassEncoder::APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex) { |
| mEncodingContext->TryEncode( |
| this, |
| [&](CommandAllocator* allocator) -> MaybeError { |
| if (IsValidationEnabled()) { |
| DAWN_TRY(GetDevice()->ValidateObject(querySet)); |
| DAWN_TRY(ValidateTimestampQuery(querySet, queryIndex)); |
| } |
| |
| mCommandEncoder->TrackQueryAvailability(querySet, queryIndex); |
| |
| WriteTimestampCmd* cmd = |
| allocator->Allocate<WriteTimestampCmd>(Command::WriteTimestamp); |
| cmd->querySet = querySet; |
| cmd->queryIndex = queryIndex; |
| |
| return {}; |
| }, |
| "encoding %s.WriteTimestamp(%s, %u).", this, querySet, queryIndex); |
| } |
| |
| void ComputePassEncoder::AddDispatchSyncScope(SyncScopeUsageTracker scope) { |
| PipelineLayoutBase* layout = mCommandBufferState.GetPipelineLayout(); |
| for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) { |
| scope.AddBindGroup(mCommandBufferState.GetBindGroup(i)); |
| } |
| mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage()); |
| } |
| |
| void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) { |
| // Encode commands for the backend to restore the pipeline and bind groups. |
| if (state.HasPipeline()) { |
| APISetPipeline(state.GetComputePipeline()); |
| } |
| for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) { |
| BindGroupBase* bg = state.GetBindGroup(i); |
| if (bg != nullptr) { |
| const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i); |
| if (offsets.empty()) { |
| APISetBindGroup(static_cast<uint32_t>(i), bg); |
| } else { |
| APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data()); |
| } |
| } |
| } |
| |
| // Restore the frontend state tracking information. |
| mCommandBufferState = std::move(state); |
| } |
| |
| CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() { |
| return &mCommandBufferState; |
| } |
| |
| } // namespace dawn_native |