Implement GPU-based validation for dispatchIndirect
Bug: dawn:1039
Change-Id: I1b77244d33b178c8e4d4b7d72dc038ccb9d65c48
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67142
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn_native/CommandBuffer.cpp b/src/dawn_native/CommandBuffer.cpp
index 18fef0d..a964f40 100644
--- a/src/dawn_native/CommandBuffer.cpp
+++ b/src/dawn_native/CommandBuffer.cpp
@@ -73,6 +73,10 @@
return mResourceUsages;
}
+ CommandIterator* CommandBufferBase::GetCommandIteratorForTesting() {
+ return &mCommands;
+ }
+
bool IsCompleteSubresourceCopiedTo(const TextureBase* texture,
const Extent3D copySize,
const uint32_t mipLevel) {
diff --git a/src/dawn_native/CommandBuffer.h b/src/dawn_native/CommandBuffer.h
index 2800929..c6d47ae 100644
--- a/src/dawn_native/CommandBuffer.h
+++ b/src/dawn_native/CommandBuffer.h
@@ -43,6 +43,8 @@
const CommandBufferResourceUsage& GetResourceUsages() const;
+ CommandIterator* GetCommandIteratorForTesting();
+
protected:
~CommandBufferBase() override;
diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp
index 5210936..45892a6 100644
--- a/src/dawn_native/CommandBufferStateTracker.cpp
+++ b/src/dawn_native/CommandBufferStateTracker.cpp
@@ -17,8 +17,10 @@
#include "common/Assert.h"
#include "common/BitSetIterator.h"
#include "dawn_native/BindGroup.h"
+#include "dawn_native/ComputePassEncoder.h"
#include "dawn_native/ComputePipeline.h"
#include "dawn_native/Forward.h"
+#include "dawn_native/ObjectType_autogen.h"
#include "dawn_native/PipelineLayout.h"
#include "dawn_native/RenderPipeline.h"
@@ -83,13 +85,15 @@
MaybeError CommandBufferStateTracker::ValidateBufferInRangeForVertexBuffer(
uint32_t vertexCount,
uint32_t firstVertex) {
+ RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+
const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
vertexBufferSlotsUsedAsVertexBuffer =
- mLastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
+ lastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
for (auto usedSlotVertex : IterateBitSet(vertexBufferSlotsUsedAsVertexBuffer)) {
const VertexBufferInfo& vertexBuffer =
- mLastRenderPipeline->GetVertexBuffer(usedSlotVertex);
+ lastRenderPipeline->GetVertexBuffer(usedSlotVertex);
uint64_t arrayStride = vertexBuffer.arrayStride;
uint64_t bufferSize = mVertexBufferSizes[usedSlotVertex];
@@ -120,13 +124,15 @@
MaybeError CommandBufferStateTracker::ValidateBufferInRangeForInstanceBuffer(
uint32_t instanceCount,
uint32_t firstInstance) {
+ RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+
const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
vertexBufferSlotsUsedAsInstanceBuffer =
- mLastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
+ lastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
for (auto usedSlotInstance : IterateBitSet(vertexBufferSlotsUsedAsInstanceBuffer)) {
const VertexBufferInfo& vertexBuffer =
- mLastRenderPipeline->GetVertexBuffer(usedSlotInstance);
+ lastRenderPipeline->GetVertexBuffer(usedSlotInstance);
uint64_t arrayStride = vertexBuffer.arrayStride;
uint64_t bufferSize = mVertexBufferSizes[usedSlotInstance];
if (arrayStride == 0) {
@@ -209,18 +215,19 @@
}
if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
- ASSERT(mLastRenderPipeline != nullptr);
+ RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& requiredVertexBuffers =
- mLastRenderPipeline->GetVertexBufferSlotsUsed();
+ lastRenderPipeline->GetVertexBufferSlotsUsed();
if (IsSubset(requiredVertexBuffers, mVertexBufferSlotsUsed)) {
mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
}
}
if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) {
- if (!IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) ||
- mIndexFormat == mLastRenderPipeline->GetStripIndexFormat()) {
+ RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+ if (!IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) ||
+ mIndexFormat == lastRenderPipeline->GetStripIndexFormat()) {
mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
}
}
@@ -234,12 +241,13 @@
if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_INDEX_BUFFER])) {
DAWN_INVALID_IF(!mIndexBufferSet, "Index buffer was not set.");
- wgpu::IndexFormat pipelineIndexFormat = mLastRenderPipeline->GetStripIndexFormat();
+ RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+ wgpu::IndexFormat pipelineIndexFormat = lastRenderPipeline->GetStripIndexFormat();
DAWN_INVALID_IF(
- IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) &&
+ IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) &&
mIndexFormat != pipelineIndexFormat,
"Strip index format (%s) of %s does not match index buffer format (%s).",
- pipelineIndexFormat, mLastRenderPipeline, mIndexFormat);
+ pipelineIndexFormat, lastRenderPipeline, mIndexFormat);
// The chunk of code above should be similar to the one in |RecomputeLazyAspects|.
// It returns the first invalid state found. We shouldn't be able to reach this line
@@ -251,7 +259,7 @@
// TODO(dawn:563): Indicate which slots were not set.
DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_VERTEX_BUFFERS],
- "Vertex buffer slots required by %s were not set.", mLastRenderPipeline);
+ "Vertex buffer slots required by %s were not set.", GetRenderPipeline());
if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_BIND_GROUPS])) {
for (BindGroupIndex i : IterateBitSet(mLastPipelineLayout->GetBindGroupLayoutsMask())) {
@@ -290,12 +298,15 @@
}
void CommandBufferStateTracker::SetRenderPipeline(RenderPipelineBase* pipeline) {
- mLastRenderPipeline = pipeline;
SetPipelineCommon(pipeline);
}
- void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup) {
+ void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index,
+ BindGroupBase* bindgroup,
+ uint32_t dynamicOffsetCount,
+ const uint32_t* dynamicOffsets) {
mBindgroups[index] = bindgroup;
+ mDynamicOffsets[index].assign(dynamicOffsets, dynamicOffsets + dynamicOffsetCount);
mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS);
}
@@ -311,8 +322,9 @@
}
void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
- mLastPipelineLayout = pipeline->GetLayout();
- mMinBufferSizes = &pipeline->GetMinBufferSizes();
+ mLastPipeline = pipeline;
+ mLastPipelineLayout = pipeline != nullptr ? pipeline->GetLayout() : nullptr;
+ mMinBufferSizes = pipeline != nullptr ? &pipeline->GetMinBufferSizes() : nullptr;
mAspects.set(VALIDATION_ASPECT_PIPELINE);
@@ -324,6 +336,25 @@
return mBindgroups[index];
}
+ const std::vector<uint32_t>& CommandBufferStateTracker::GetDynamicOffsets(
+ BindGroupIndex index) const {
+ return mDynamicOffsets[index];
+ }
+
+ bool CommandBufferStateTracker::HasPipeline() const {
+ return mLastPipeline != nullptr;
+ }
+
+ RenderPipelineBase* CommandBufferStateTracker::GetRenderPipeline() const {
+ ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::RenderPipeline);
+ return static_cast<RenderPipelineBase*>(mLastPipeline);
+ }
+
+ ComputePipelineBase* CommandBufferStateTracker::GetComputePipeline() const {
+ ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::ComputePipeline);
+ return static_cast<ComputePipelineBase*>(mLastPipeline);
+ }
+
PipelineLayoutBase* CommandBufferStateTracker::GetPipelineLayout() const {
return mLastPipelineLayout;
}
diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h
index 0a6c587..5686956 100644
--- a/src/dawn_native/CommandBufferStateTracker.h
+++ b/src/dawn_native/CommandBufferStateTracker.h
@@ -38,7 +38,10 @@
// State-modifying methods
void SetComputePipeline(ComputePipelineBase* pipeline);
void SetRenderPipeline(RenderPipelineBase* pipeline);
- void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup);
+ void SetBindGroup(BindGroupIndex index,
+ BindGroupBase* bindgroup,
+ uint32_t dynamicOffsetCount,
+ const uint32_t* dynamicOffsets);
void SetIndexBuffer(wgpu::IndexFormat format, uint64_t size);
void SetVertexBuffer(VertexBufferSlot slot, uint64_t size);
@@ -46,6 +49,10 @@
using ValidationAspects = std::bitset<kNumAspects>;
BindGroupBase* GetBindGroup(BindGroupIndex index) const;
+ const std::vector<uint32_t>& GetDynamicOffsets(BindGroupIndex index) const;
+ bool HasPipeline() const;
+ RenderPipelineBase* GetRenderPipeline() const;
+ ComputePipelineBase* GetComputePipeline() const;
PipelineLayoutBase* GetPipelineLayout() const;
wgpu::IndexFormat GetIndexFormat() const;
uint64_t GetIndexBufferSize() const;
@@ -60,6 +67,7 @@
ValidationAspects mAspects;
ityp::array<BindGroupIndex, BindGroupBase*, kMaxBindGroups> mBindgroups = {};
+ ityp::array<BindGroupIndex, std::vector<uint32_t>, kMaxBindGroups> mDynamicOffsets = {};
ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mVertexBufferSlotsUsed;
bool mIndexBufferSet = false;
wgpu::IndexFormat mIndexFormat;
@@ -68,7 +76,7 @@
ityp::array<VertexBufferSlot, uint64_t, kMaxVertexBuffers> mVertexBufferSizes = {};
PipelineLayoutBase* mLastPipelineLayout = nullptr;
- RenderPipelineBase* mLastRenderPipeline = nullptr;
+ PipelineBase* mLastPipeline = nullptr;
const RequiredBufferSizes* mMinBufferSizes = nullptr;
};
diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp
index 1aa4845..05c68fb 100644
--- a/src/dawn_native/ComputePassEncoder.cpp
+++ b/src/dawn_native/ComputePassEncoder.cpp
@@ -14,18 +14,107 @@
#include "dawn_native/ComputePassEncoder.h"
+#include "dawn_native/BindGroup.h"
+#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Buffer.h"
#include "dawn_native/CommandEncoder.h"
#include "dawn_native/CommandValidation.h"
#include "dawn_native/Commands.h"
#include "dawn_native/ComputePipeline.h"
#include "dawn_native/Device.h"
+#include "dawn_native/InternalPipelineStore.h"
#include "dawn_native/ObjectType_autogen.h"
#include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/QuerySet.h"
namespace dawn_native {
+ namespace {
+
+ ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
+ DeviceBase* device) {
+ InternalPipelineStore* store = device->GetInternalPipelineStore();
+
+ if (store->dispatchIndirectValidationPipeline != nullptr) {
+ return store->dispatchIndirectValidationPipeline.Get();
+ }
+
+ ShaderModuleDescriptor descriptor;
+ ShaderModuleWGSLDescriptor wgslDesc;
+ descriptor.nextInChain = reinterpret_cast<ChainedStruct*>(&wgslDesc);
+
+ // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
+ // shader in various failure modes.
+ wgslDesc.source = R"(
+ [[block]] struct UniformParams {
+ maxComputeWorkgroupsPerDimension: u32;
+ clientOffsetInU32: u32;
+ };
+
+ [[block]] struct IndirectParams {
+ data: array<u32>;
+ };
+
+ [[block]] struct ValidatedParams {
+ data: array<u32, 3>;
+ };
+
+ [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
+ [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
+ [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;
+
+ [[stage(compute), workgroup_size(1, 1, 1)]]
+ fn main() {
+ for (var i = 0u; i < 3u; i = i + 1u) {
+ var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
+ if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
+ numWorkgroups = 0u;
+ }
+ validatedParams.data[i] = numWorkgroups;
+ }
+ }
+ )";
+
+ Ref<ShaderModuleBase> shaderModule;
+ DAWN_TRY_ASSIGN(shaderModule, device->CreateShaderModule(&descriptor));
+
+ std::array<BindGroupLayoutEntry, 3> entries;
+ entries[0].binding = 0;
+ entries[0].visibility = wgpu::ShaderStage::Compute;
+ entries[0].buffer.type = wgpu::BufferBindingType::Uniform;
+ entries[1].binding = 1;
+ entries[1].visibility = wgpu::ShaderStage::Compute;
+ entries[1].buffer.type = kInternalStorageBufferBinding;
+ entries[2].binding = 2;
+ entries[2].visibility = wgpu::ShaderStage::Compute;
+ entries[2].buffer.type = wgpu::BufferBindingType::Storage;
+
+ BindGroupLayoutDescriptor bindGroupLayoutDescriptor;
+ bindGroupLayoutDescriptor.entryCount = entries.size();
+ bindGroupLayoutDescriptor.entries = entries.data();
+ Ref<BindGroupLayoutBase> bindGroupLayout;
+ DAWN_TRY_ASSIGN(bindGroupLayout,
+ device->CreateBindGroupLayout(&bindGroupLayoutDescriptor, true));
+
+ PipelineLayoutDescriptor pipelineDescriptor;
+ pipelineDescriptor.bindGroupLayoutCount = 1;
+ pipelineDescriptor.bindGroupLayouts = &bindGroupLayout.Get();
+ Ref<PipelineLayoutBase> pipelineLayout;
+ DAWN_TRY_ASSIGN(pipelineLayout, device->CreatePipelineLayout(&pipelineDescriptor));
+
+ ComputePipelineDescriptor computePipelineDescriptor = {};
+ computePipelineDescriptor.layout = pipelineLayout.Get();
+ computePipelineDescriptor.compute.module = shaderModule.Get();
+ computePipelineDescriptor.compute.entryPoint = "main";
+
+ DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
+ device->CreateComputePipeline(&computePipelineDescriptor));
+
+ return store->dispatchIndirectValidationPipeline.Get();
+ }
+
+ } // namespace
+
ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext)
@@ -107,6 +196,95 @@
"encoding Dispatch (x: %u, y: %u, z: %u)", x, y, z);
}
+ ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
+ ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer,
+ uint64_t indirectOffset) {
+ DeviceBase* device = GetDevice();
+ auto* const store = device->GetInternalPipelineStore();
+
+ Ref<ComputePipelineBase> validationPipeline;
+ DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));
+
+ Ref<BindGroupLayoutBase> layout;
+ DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));
+
+ uint32_t storageBufferOffsetAlignment =
+ device->GetLimits().v1.minStorageBufferOffsetAlignment;
+
+ std::array<BindGroupEntry, 3> bindings;
+
+ // Storage binding holding the client's indirect buffer.
+ BindGroupEntry& clientIndirectBinding = bindings[0];
+ clientIndirectBinding.binding = 1;
+ clientIndirectBinding.buffer = indirectBuffer;
+
+ // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|.
+ const uint32_t clientOffsetFromAlignedBoundary =
+ indirectOffset % storageBufferOffsetAlignment;
+ const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
+ clientIndirectBinding.offset = clientOffsetAlignedDown;
+
+ // Let the size of the binding be the additional offset, plus the size.
+ clientIndirectBinding.size = kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
+
+ struct UniformParams {
+ uint32_t maxComputeWorkgroupsPerDimension;
+ uint32_t clientOffsetInU32;
+ };
+
+ // Create a uniform buffer to hold parameters for the shader.
+ Ref<BufferBase> uniformBuffer;
+ {
+ BufferDescriptor uniformDesc = {};
+ uniformDesc.size = sizeof(UniformParams);
+ uniformDesc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
+ uniformDesc.mappedAtCreation = true;
+ DAWN_TRY_ASSIGN(uniformBuffer, device->CreateBuffer(&uniformDesc));
+
+ UniformParams* params = static_cast<UniformParams*>(
+ uniformBuffer->GetMappedRange(0, sizeof(UniformParams)));
+ params->maxComputeWorkgroupsPerDimension =
+ device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
+ params->clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
+ uniformBuffer->Unmap();
+ }
+
+ // Uniform buffer binding pointing to the uniform parameters.
+ BindGroupEntry& uniformBinding = bindings[1];
+ uniformBinding.binding = 0;
+ uniformBinding.buffer = uniformBuffer.Get();
+ uniformBinding.offset = 0;
+ uniformBinding.size = sizeof(UniformParams);
+
+ // Reserve space in the scratch buffer to hold the validated indirect params.
+ ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
+ DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize));
+ Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
+
+ // Binding for the validated indirect params.
+ BindGroupEntry& validatedParamsBinding = bindings[2];
+ validatedParamsBinding.binding = 2;
+ validatedParamsBinding.buffer = validatedIndirectBuffer.Get();
+ validatedParamsBinding.offset = 0;
+ validatedParamsBinding.size = kDispatchIndirectSize;
+
+ BindGroupDescriptor bindGroupDescriptor = {};
+ bindGroupDescriptor.layout = layout.Get();
+ bindGroupDescriptor.entryCount = bindings.size();
+ bindGroupDescriptor.entries = bindings.data();
+
+ Ref<BindGroupBase> validationBindGroup;
+ DAWN_TRY_ASSIGN(validationBindGroup, device->CreateBindGroup(&bindGroupDescriptor));
+
+ // Issue commands to validate the indirect buffer.
+ APISetPipeline(validationPipeline.Get());
+ APISetBindGroup(0, validationBindGroup.Get());
+ APIDispatch(1);
+
+ // Return the new indirect buffer and indirect buffer offset.
+ return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
+ }
+
void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
uint64_t indirectOffset) {
mEncodingContext->TryEncode(
@@ -136,18 +314,46 @@
indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
}
- // Record the synchronization scope for Dispatch, both the bindgroups and the
- // indirect buffer.
SyncScopeUsageTracker scope;
scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
mUsageTracker.AddReferencedBuffer(indirectBuffer);
+ // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
+ // is needed for correct usage validation even though it will only be bound for
+ // storage. This will unecessarily transition the |indirectBuffer| in
+ // the backend.
+
+ Ref<BufferBase> indirectBufferRef = indirectBuffer;
+ if (IsValidationEnabled()) {
+ // Save the previous command buffer state so it can be restored after the
+ // validation inserts additional commands.
+ CommandBufferStateTracker previousState = mCommandBufferState;
+
+ // Validate each indirect dispatch with a single dispatch to copy the indirect
+ // buffer params into a scratch buffer if they're valid, and otherwise zero them
+ // out. We could consider moving the validation earlier in the pass after the
+ // last point the indirect buffer was used with writable usage, as well as batch
+ // validation for multiple dispatches into one, but inserting commands at
+ // arbitrary points in the past is not possible right now.
+ DAWN_TRY_ASSIGN(
+ std::tie(indirectBufferRef, indirectOffset),
+ ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset));
+
+ // Restore the state.
+ RestoreCommandBufferState(std::move(previousState));
+
+ // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
+ // synchronization scope.
+ ASSERT(indirectBufferRef.Get() != indirectBuffer);
+ scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
+ mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
+ }
+
AddDispatchSyncScope(std::move(scope));
DispatchIndirectCmd* dispatch =
allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
- dispatch->indirectBuffer = indirectBuffer;
+ dispatch->indirectBuffer = std::move(indirectBufferRef);
dispatch->indirectOffset = indirectOffset;
-
return {};
},
"encoding DispatchIndirect with %s", indirectBuffer);
@@ -187,10 +393,10 @@
}
mUsageTracker.AddResourcesReferencedByBindGroup(group);
-
RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
dynamicOffsets);
- mCommandBufferState.SetBindGroup(groupIndex, group);
+ mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
+ dynamicOffsets);
return {};
},
@@ -226,4 +432,29 @@
mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
}
+ void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
+ // Encode commands for the backend to restore the pipeline and bind groups.
+ if (state.HasPipeline()) {
+ APISetPipeline(state.GetComputePipeline());
+ }
+ for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
+ BindGroupBase* bg = state.GetBindGroup(i);
+ if (bg != nullptr) {
+ const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
+ if (offsets.empty()) {
+ APISetBindGroup(static_cast<uint32_t>(i), bg);
+ } else {
+ APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data());
+ }
+ }
+ }
+
+ // Restore the frontend state tracking information.
+ mCommandBufferState = std::move(state);
+ }
+
+ CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
+ return &mCommandBufferState;
+ }
+
} // namespace dawn_native
diff --git a/src/dawn_native/ComputePassEncoder.h b/src/dawn_native/ComputePassEncoder.h
index b0962f4..03997ce 100644
--- a/src/dawn_native/ComputePassEncoder.h
+++ b/src/dawn_native/ComputePassEncoder.h
@@ -50,6 +50,11 @@
void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex);
+ CommandBufferStateTracker* GetCommandBufferStateTrackerForTesting();
+ void RestoreCommandBufferStateForTesting(CommandBufferStateTracker state) {
+ RestoreCommandBufferState(std::move(state));
+ }
+
protected:
ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
@@ -57,6 +62,12 @@
ErrorTag errorTag);
private:
+ ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
+ BufferBase* indirectBuffer,
+ uint64_t indirectOffset);
+
+ void RestoreCommandBufferState(CommandBufferStateTracker state);
+
CommandBufferStateTracker mCommandBufferState;
// Adds the bindgroups used for the current dispatch to the SyncScopeResourceUsage and
diff --git a/src/dawn_native/InternalPipelineStore.h b/src/dawn_native/InternalPipelineStore.h
index acf3b13..803e0df 100644
--- a/src/dawn_native/InternalPipelineStore.h
+++ b/src/dawn_native/InternalPipelineStore.h
@@ -52,6 +52,7 @@
Ref<ComputePipelineBase> renderValidationPipeline;
Ref<ShaderModuleBase> renderValidationShader;
+ Ref<ComputePipelineBase> dispatchIndirectValidationPipeline;
};
} // namespace dawn_native
diff --git a/src/dawn_native/RenderEncoderBase.cpp b/src/dawn_native/RenderEncoderBase.cpp
index 0445a97..a8ef2ff 100644
--- a/src/dawn_native/RenderEncoderBase.cpp
+++ b/src/dawn_native/RenderEncoderBase.cpp
@@ -208,6 +208,9 @@
BufferLocation::New(indirectBuffer, indirectOffset);
}
+ // TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
+ // validation, but it will unnecessarily transition to indirectBuffer usage in the
+ // backend.
mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
return {};
@@ -404,7 +407,8 @@
RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
dynamicOffsets);
- mCommandBufferState.SetBindGroup(groupIndex, group);
+ mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
+ dynamicOffsets);
mUsageTracker.AddBindGroup(group);
return {};
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index c7f4b7a..8ffd921 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -221,6 +221,7 @@
"unittests/SystemUtilsTests.cpp",
"unittests/ToBackendTests.cpp",
"unittests/TypedIntegerTests.cpp",
+ "unittests/native/CommandBufferEncodingTests.cpp",
"unittests/native/DestroyObjectTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp",
"unittests/validation/BufferValidationTests.cpp",
diff --git a/src/tests/DawnNativeTest.cpp b/src/tests/DawnNativeTest.cpp
index d39c8e0..28d69bf 100644
--- a/src/tests/DawnNativeTest.cpp
+++ b/src/tests/DawnNativeTest.cpp
@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <gtest/gtest.h>
+#include "tests/DawnNativeTest.h"
#include "absl/strings/str_cat.h"
+#include "common/Assert.h"
+#include "dawn/dawn_proc.h"
#include "dawn_native/ErrorData.h"
namespace dawn_native {
@@ -28,3 +30,54 @@
}
} // namespace dawn_native
+
+DawnNativeTest::DawnNativeTest() {
+ dawnProcSetProcs(&dawn_native::GetProcs());
+}
+
+DawnNativeTest::~DawnNativeTest() {
+ device = wgpu::Device();
+ dawnProcSetProcs(nullptr);
+}
+
+void DawnNativeTest::SetUp() {
+ instance = std::make_unique<dawn_native::Instance>();
+ instance->DiscoverDefaultAdapters();
+
+ std::vector<dawn_native::Adapter> adapters = instance->GetAdapters();
+
+ // DawnNative unittests run against the null backend, find the corresponding adapter
+ bool foundNullAdapter = false;
+ for (auto& currentAdapter : adapters) {
+ wgpu::AdapterProperties adapterProperties;
+ currentAdapter.GetProperties(&adapterProperties);
+
+ if (adapterProperties.backendType == wgpu::BackendType::Null) {
+ adapter = currentAdapter;
+ foundNullAdapter = true;
+ break;
+ }
+ }
+
+ ASSERT(foundNullAdapter);
+
+ device = wgpu::Device(CreateTestDevice());
+ device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
+}
+
+void DawnNativeTest::TearDown() {
+}
+
+WGPUDevice DawnNativeTest::CreateTestDevice() {
+ // Disable disallowing unsafe APIs so we can test them.
+ dawn_native::DeviceDescriptor deviceDescriptor;
+ deviceDescriptor.forceDisabledToggles.push_back("disallow_unsafe_apis");
+
+ return adapter.CreateDevice(&deviceDescriptor);
+}
+
+// static
+void DawnNativeTest::OnDeviceError(WGPUErrorType type, const char* message, void* userdata) {
+ ASSERT(type != WGPUErrorType_NoError);
+ FAIL() << "Unexpected error: " << message;
+}
diff --git a/src/tests/DawnNativeTest.h b/src/tests/DawnNativeTest.h
index 94fdafb..91904a3 100644
--- a/src/tests/DawnNativeTest.h
+++ b/src/tests/DawnNativeTest.h
@@ -17,6 +17,8 @@
#include <gtest/gtest.h>
+#include "dawn/webgpu_cpp.h"
+#include "dawn_native/DawnNative.h"
#include "dawn_native/ErrorData.h"
namespace dawn_native {
@@ -29,4 +31,23 @@
} // namespace dawn_native
+class DawnNativeTest : public ::testing::Test {
+ public:
+ DawnNativeTest();
+ ~DawnNativeTest() override;
+
+ void SetUp() override;
+ void TearDown() override;
+
+ virtual WGPUDevice CreateTestDevice();
+
+ protected:
+ std::unique_ptr<dawn_native::Instance> instance;
+ dawn_native::Adapter adapter;
+ wgpu::Device device;
+
+ private:
+ static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata);
+};
+
#endif // TESTS_DAWNNATIVETEST_H_
diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp
index 1a8b163..cbc2d86 100644
--- a/src/tests/end2end/ComputeDispatchTests.cpp
+++ b/src/tests/end2end/ComputeDispatchTests.cpp
@@ -158,8 +158,14 @@
queue.Submit(1, &commands);
std::vector<uint32_t> expected;
+
+ uint32_t maxComputeWorkgroupsPerDimension =
+ GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
if (indirectBufferData[indirectStart] == 0 || indirectBufferData[indirectStart + 1] == 0 ||
- indirectBufferData[indirectStart + 2] == 0) {
+ indirectBufferData[indirectStart + 2] == 0 ||
+ indirectBufferData[indirectStart] > maxComputeWorkgroupsPerDimension ||
+ indirectBufferData[indirectStart + 1] > maxComputeWorkgroupsPerDimension ||
+ indirectBufferData[indirectStart + 2] > maxComputeWorkgroupsPerDimension) {
expected = kSentinelData;
} else {
expected.assign(indirectBufferData.begin() + indirectStart,
@@ -221,6 +227,52 @@
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
}
+// Test indirect dispatches at max limit.
+TEST_P(ComputeDispatchTests, MaxWorkgroups) {
+ // TODO(crbug.com/dawn/1165): Fails with WARP
+ DAWN_SUPPRESS_TEST_IF(IsWARP());
+ uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+ // Test that the maximum works in each dimension.
+ // Note: Testing (max, max, max) is very slow.
+ IndirectTest({max, 3, 4}, 0);
+ IndirectTest({2, max, 4}, 0);
+ IndirectTest({2, 3, max}, 0);
+}
+
+// Test indirect dispatches exceeding the max limit are noop-ed.
+TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
+ DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
+
+ uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+ // All dimensions are above the max
+ IndirectTest({max + 1, max + 1, max + 1}, 0);
+
+ // Only x dimension is above the max
+ IndirectTest({max + 1, 3, 4}, 0);
+ IndirectTest({2 * max, 3, 4}, 0);
+
+ // Only y dimension is above the max
+ IndirectTest({2, max + 1, 4}, 0);
+ IndirectTest({2, 2 * max, 4}, 0);
+
+ // Only z dimension is above the max
+ IndirectTest({2, 3, max + 1}, 0);
+ IndirectTest({2, 3, 2 * max}, 0);
+}
+
+// Test indirect dispatches exceeding the max limit with an offset are noop-ed.
+TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
+ DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
+
+ uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+ IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
+ IndirectTest({1, 2, 3, max + 1, 4, 5}, 2 * sizeof(uint32_t));
+ IndirectTest({1, 2, 3, max + 1, 4, 5}, 3 * sizeof(uint32_t));
+}
+
DAWN_INSTANTIATE_TEST(ComputeDispatchTests,
D3D12Backend(),
MetalBackend(),
diff --git a/src/tests/unittests/native/CommandBufferEncodingTests.cpp b/src/tests/unittests/native/CommandBufferEncodingTests.cpp
new file mode 100644
index 0000000..c1ca2d9
--- /dev/null
+++ b/src/tests/unittests/native/CommandBufferEncodingTests.cpp
@@ -0,0 +1,310 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnNativeTest.h"
+
+#include "dawn_native/CommandBuffer.h"
+#include "dawn_native/Commands.h"
+#include "dawn_native/ComputePassEncoder.h"
+#include "utils/WGPUHelpers.h"
+
+class CommandBufferEncodingTests : public DawnNativeTest {
+ protected:
+ void ExpectCommands(dawn_native::CommandIterator* commands,
+ std::vector<std::pair<dawn_native::Command,
+ std::function<void(dawn_native::CommandIterator*)>>>
+ expectedCommands) {
+ dawn_native::Command commandId;
+ for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
+ ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
+ ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
+ << "at command " << commandIndex;
+ expectedCommands[commandIndex].second(commands);
+ }
+ }
+};
+
+// Indirect dispatch validation changes the bind groups in the middle
+// of a pass. Test that bindings are restored after the validation runs.
+TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
+ using namespace dawn_native;
+
+ wgpu::BindGroupLayout staticLayout =
+ utils::MakeBindGroupLayout(device, {{
+ 0,
+ wgpu::ShaderStage::Compute,
+ wgpu::BufferBindingType::Uniform,
+ }});
+
+ wgpu::BindGroupLayout dynamicLayout =
+ utils::MakeBindGroupLayout(device, {{
+ 0,
+ wgpu::ShaderStage::Compute,
+ wgpu::BufferBindingType::Uniform,
+ true,
+ }});
+
+ // Create a simple pipeline
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.compute.module = utils::CreateShaderModule(device, R"(
+ [[stage(compute), workgroup_size(1, 1, 1)]]
+ fn main() {
+ })");
+ csDesc.compute.entryPoint = "main";
+
+ wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
+ csDesc.layout = pl0;
+ wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);
+
+ wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
+ csDesc.layout = pl1;
+ wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);
+
+ // Create buffers to use for both the indirect buffer and the bind groups.
+ wgpu::Buffer indirectBuffer =
+ utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});
+
+ wgpu::BufferDescriptor uniformBufferDesc = {};
+ uniformBufferDesc.size = 512;
+ uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
+ wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);
+
+ wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});
+
+ wgpu::BindGroup dynamicBG =
+ utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});
+
+ uint32_t dynamicOffset = 256;
+ std::vector<uint32_t> emptyDynamicOffsets = {};
+ std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};
+
+ // Begin encoding commands.
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+ CommandBufferStateTracker* stateTracker =
+ FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
+
+ // Perform a dispatch indirect which will be preceded by a validation dispatch.
+ pass.SetPipeline(pipeline0);
+ pass.SetBindGroup(0, staticBG);
+ pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
+ EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
+
+ pass.DispatchIndirect(indirectBuffer, 0);
+
+ // Expect restored state.
+ EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
+ EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
+ EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
+ EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
+
+ // Dispatch again to check that the restored state can be used.
+ // Also pass an indirect offset which should get replaced with the offset
+ // into the scratch indirect buffer (0).
+ pass.DispatchIndirect(indirectBuffer, 4);
+
+ // Expect restored state.
+ EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
+ EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
+ EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
+ EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
+
+ // Change the pipeline
+ pass.SetPipeline(pipeline1);
+ pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
+ pass.SetBindGroup(1, staticBG);
+ EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
+
+ pass.DispatchIndirect(indirectBuffer, 0);
+
+ // Expect restored state.
+ EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
+ EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
+ EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
+ EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
+ EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);
+
+ pass.EndPass();
+
+ wgpu::CommandBuffer commandBuffer = encoder.Finish();
+
+ auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
+ return [pipeline](CommandIterator* commands) {
+ auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
+ EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
+ };
+ };
+
+ auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
+ std::vector<uint32_t> offsets = {}) {
+ return [index, bg, offsets](CommandIterator* commands) {
+ auto* cmd = commands->NextCommand<SetBindGroupCmd>();
+ uint32_t* dynamicOffsets = nullptr;
+ if (cmd->dynamicOffsetCount > 0) {
+ dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
+ }
+
+ ASSERT_EQ(cmd->index, BindGroupIndex(index));
+ ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
+ ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
+ for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
+ ASSERT_EQ(dynamicOffsets[i], offsets[i]);
+ }
+ };
+ };
+
+ // Initialize as null. Once we know the pointer, we'll check
+ // that it's the same buffer every time.
+ WGPUBuffer indirectScratchBuffer = nullptr;
+ auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
+ auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
+ if (indirectScratchBuffer == nullptr) {
+ indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get());
+ }
+ ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer);
+ ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
+ };
+
+ // Initialize as null. Once we know the pointer, we'll check
+ // that it's the same pipeline every time.
+ WGPUComputePipeline validationPipeline = nullptr;
+ auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
+ auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
+ WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
+ if (validationPipeline != nullptr) {
+ EXPECT_EQ(pipeline, validationPipeline);
+ } else {
+ EXPECT_NE(pipeline, nullptr);
+ validationPipeline = pipeline;
+ }
+ };
+
+ auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
+ auto* cmd = commands->NextCommand<SetBindGroupCmd>();
+ ASSERT_EQ(cmd->index, BindGroupIndex(0));
+ ASSERT_NE(cmd->group.Get(), nullptr);
+ ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
+ };
+
+ auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
+ auto* cmd = commands->NextCommand<DispatchCmd>();
+ ASSERT_EQ(cmd->x, 1u);
+ ASSERT_EQ(cmd->y, 1u);
+ ASSERT_EQ(cmd->z, 1u);
+ };
+
+ ExpectCommands(
+ FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
+ {
+ {Command::BeginComputePass,
+ [&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }},
+ // Expect the state to be set.
+ {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
+ {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
+ {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
+
+ // Expect the validation.
+ {Command::SetComputePipeline, ExpectSetValidationPipeline},
+ {Command::SetBindGroup, ExpectSetValidationBindGroup},
+ {Command::Dispatch, ExpectSetValidationDispatch},
+
+ // Expect the state to be restored.
+ {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
+ {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
+ {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
+
+ // Expect the dispatchIndirect.
+ {Command::DispatchIndirect, ExpectDispatchIndirect},
+
+ // Expect the validation.
+ {Command::SetComputePipeline, ExpectSetValidationPipeline},
+ {Command::SetBindGroup, ExpectSetValidationBindGroup},
+ {Command::Dispatch, ExpectSetValidationDispatch},
+
+ // Expect the state to be restored.
+ {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
+ {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
+ {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
+
+ // Expect the dispatchIndirect.
+ {Command::DispatchIndirect, ExpectDispatchIndirect},
+
+ // Expect the state to be set (new pipeline).
+ {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
+ {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
+ {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
+
+ // Expect the validation.
+ {Command::SetComputePipeline, ExpectSetValidationPipeline},
+ {Command::SetBindGroup, ExpectSetValidationBindGroup},
+ {Command::Dispatch, ExpectSetValidationDispatch},
+
+ // Expect the state to be restored.
+ {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
+ {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
+ {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
+
+ // Expect the dispatchIndirect.
+ {Command::DispatchIndirect, ExpectDispatchIndirect},
+
+ {Command::EndComputePass,
+ [&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
+ });
+}
+
+// Test that after restoring state, it is fully applied to the state tracker
+// and does not leak state changes that occurred between a snapshot and the
+// state restoration.
+TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
+ using namespace dawn_native;
+
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+
+ CommandBufferStateTracker* stateTracker =
+ FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
+
+ // Snapshot the state.
+ CommandBufferStateTracker snapshot = *stateTracker;
+ // Expect no pipeline in the snapshot
+ EXPECT_FALSE(snapshot.HasPipeline());
+
+ // Create a simple pipeline
+ wgpu::ComputePipelineDescriptor csDesc;
+ csDesc.compute.module = utils::CreateShaderModule(device, R"(
+ [[stage(compute), workgroup_size(1, 1, 1)]]
+ fn main() {
+ })");
+ csDesc.compute.entryPoint = "main";
+ wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
+
+ // Set the pipeline.
+ pass.SetPipeline(pipeline);
+
+ // Expect the pipeline to be set.
+ EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get());
+
+ // Restore the state.
+ FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot));
+
+ // Expect no pipeline
+ EXPECT_FALSE(stateTracker->HasPipeline());
+}