Metal: Add CommandRecordingContext
Introduces the idea of a CommandRecordingContext to the Metal backend,
similar to other backends. This is a class to track which Metal encoder
is open on the device-global pending MTLCommandBuffer.
It will be needed to open/close encoders for lazy clearing.
Bug: dawn:145
Change-Id: Ief6b71a079d73943677d2b61382d1c36b88a4f87
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14780
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 0329c0b..5287bee 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -355,6 +355,8 @@
"src/dawn_native/metal/BufferMTL.mm",
"src/dawn_native/metal/CommandBufferMTL.h",
"src/dawn_native/metal/CommandBufferMTL.mm",
+ "src/dawn_native/metal/CommandRecordingContext.h",
+ "src/dawn_native/metal/CommandRecordingContext.mm",
"src/dawn_native/metal/ComputePipelineMTL.h",
"src/dawn_native/metal/ComputePipelineMTL.mm",
"src/dawn_native/metal/DeviceMTL.h",
diff --git a/src/dawn_native/metal/CommandBufferMTL.h b/src/dawn_native/metal/CommandBufferMTL.h
index 640d196..67a1313 100644
--- a/src/dawn_native/metal/CommandBufferMTL.h
+++ b/src/dawn_native/metal/CommandBufferMTL.h
@@ -26,25 +26,24 @@
namespace dawn_native { namespace metal {
+ class CommandRecordingContext;
class Device;
- struct GlobalEncoders;
class CommandBuffer : public CommandBufferBase {
public:
CommandBuffer(CommandEncoder* encoder, const CommandBufferDescriptor* descriptor);
~CommandBuffer();
- void FillCommands(id<MTLCommandBuffer> commandBuffer);
+ void FillCommands(CommandRecordingContext* commandContext);
private:
- void EncodeComputePass(id<MTLCommandBuffer> commandBuffer);
- void EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+ void EncodeComputePass(CommandRecordingContext* commandContext);
+ void EncodeRenderPass(CommandRecordingContext* commandContext,
MTLRenderPassDescriptor* mtlRenderPass,
- GlobalEncoders* globalEncoders,
uint32_t width,
uint32_t height);
- void EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+ void EncodeRenderPassInternal(CommandRecordingContext* commandContext,
MTLRenderPassDescriptor* mtlRenderPass,
uint32_t width,
uint32_t height);
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 77596dd..866a6fe 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -29,23 +29,6 @@
namespace dawn_native { namespace metal {
- struct GlobalEncoders {
- id<MTLBlitCommandEncoder> blit = nil;
-
- void Finish() {
- if (blit != nil) {
- [blit endEncoding];
- blit = nil; // This will be autoreleased.
- }
- }
-
- void EnsureBlit(id<MTLCommandBuffer> commandBuffer) {
- if (blit == nil) {
- blit = [commandBuffer blitCommandEncoder];
- }
- }
- };
-
namespace {
// Allows this file to use MTLStoreActionStoreAndMultismapleResolve because the logic is
@@ -133,7 +116,7 @@
// Helper function for Toggle EmulateStoreAndMSAAResolve
void ResolveInAnotherRenderPass(
- id<MTLCommandBuffer> commandBuffer,
+ CommandRecordingContext* commandContext,
const MTLRenderPassDescriptor* mtlRenderPass,
const std::array<id<MTLTexture>, kMaxColorAttachments>& resolveTextures) {
MTLRenderPassDescriptor* mtlRenderPassForResolve =
@@ -155,9 +138,8 @@
mtlRenderPass.colorAttachments[i].resolveSlice;
}
- id<MTLRenderCommandEncoder> encoder =
- [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPassForResolve];
- [encoder endEncoding];
+ commandContext->BeginRender(mtlRenderPassForResolve);
+ commandContext->EndRender();
}
// Helper functions for Toggle AlwaysResolveIntoZeroLevelAndLayer
@@ -182,24 +164,22 @@
return resolveTexture;
}
- void CopyIntoTrueResolveTarget(id<MTLCommandBuffer> commandBuffer,
+ void CopyIntoTrueResolveTarget(CommandRecordingContext* commandContext,
id<MTLTexture> mtlTrueResolveTexture,
uint32_t trueResolveLevel,
uint32_t trueResolveSlice,
id<MTLTexture> temporaryResolveTexture,
uint32_t width,
- uint32_t height,
- GlobalEncoders* encoders) {
- encoders->EnsureBlit(commandBuffer);
- [encoders->blit copyFromTexture:temporaryResolveTexture
- sourceSlice:0
- sourceLevel:0
- sourceOrigin:MTLOriginMake(0, 0, 0)
- sourceSize:MTLSizeMake(width, height, 1)
- toTexture:mtlTrueResolveTexture
- destinationSlice:trueResolveSlice
- destinationLevel:trueResolveLevel
- destinationOrigin:MTLOriginMake(0, 0, 0)];
+ uint32_t height) {
+ [commandContext->EnsureBlit() copyFromTexture:temporaryResolveTexture
+ sourceSlice:0
+ sourceLevel:0
+ sourceOrigin:MTLOriginMake(0, 0, 0)
+ sourceSize:MTLSizeMake(width, height, 1)
+ toTexture:mtlTrueResolveTexture
+ destinationSlice:trueResolveSlice
+ destinationLevel:trueResolveLevel
+ destinationOrigin:MTLOriginMake(0, 0, 0)];
}
// Metal uses a physical addressing mode which means buffers in the shading language are
@@ -608,34 +588,33 @@
FreeCommands(&mCommands);
}
- void CommandBuffer::FillCommands(id<MTLCommandBuffer> commandBuffer) {
- GlobalEncoders encoders;
-
+ void CommandBuffer::FillCommands(CommandRecordingContext* commandContext) {
Command type;
while (mCommands.NextCommandId(&type)) {
switch (type) {
case Command::BeginComputePass: {
mCommands.NextCommand<BeginComputePassCmd>();
- encoders.Finish();
- EncodeComputePass(commandBuffer);
+
+ commandContext->EndBlit();
+ EncodeComputePass(commandContext);
} break;
case Command::BeginRenderPass: {
BeginRenderPassCmd* cmd = mCommands.NextCommand<BeginRenderPassCmd>();
- encoders.Finish();
+ commandContext->EndBlit();
MTLRenderPassDescriptor* descriptor = CreateMTLRenderPassDescriptor(cmd);
- EncodeRenderPass(commandBuffer, descriptor, &encoders, cmd->width, cmd->height);
+ EncodeRenderPass(commandContext, descriptor, cmd->width, cmd->height);
} break;
case Command::CopyBufferToBuffer: {
CopyBufferToBufferCmd* copy = mCommands.NextCommand<CopyBufferToBufferCmd>();
- encoders.EnsureBlit(commandBuffer);
- [encoders.blit copyFromBuffer:ToBackend(copy->source)->GetMTLBuffer()
- sourceOffset:copy->sourceOffset
- toBuffer:ToBackend(copy->destination)->GetMTLBuffer()
- destinationOffset:copy->destinationOffset
- size:copy->size];
+ [commandContext->EnsureBlit()
+ copyFromBuffer:ToBackend(copy->source)->GetMTLBuffer()
+ sourceOffset:copy->sourceOffset
+ toBuffer:ToBackend(copy->destination)->GetMTLBuffer()
+ destinationOffset:copy->destinationOffset
+ size:copy->size];
} break;
case Command::CopyBufferToTexture: {
@@ -651,18 +630,17 @@
dst.origin, copySize, texture->GetFormat(), virtualSizeAtLevel,
buffer->GetSize(), src.offset, src.rowPitch, src.imageHeight);
- encoders.EnsureBlit(commandBuffer);
for (uint32_t i = 0; i < splittedCopies.count; ++i) {
const TextureBufferCopySplit::CopyInfo& copyInfo = splittedCopies.copies[i];
- [encoders.blit copyFromBuffer:buffer->GetMTLBuffer()
- sourceOffset:copyInfo.bufferOffset
- sourceBytesPerRow:copyInfo.bytesPerRow
- sourceBytesPerImage:copyInfo.bytesPerImage
- sourceSize:copyInfo.copyExtent
- toTexture:texture->GetMTLTexture()
- destinationSlice:dst.arrayLayer
- destinationLevel:dst.mipLevel
- destinationOrigin:copyInfo.textureOrigin];
+ [commandContext->EnsureBlit() copyFromBuffer:buffer->GetMTLBuffer()
+ sourceOffset:copyInfo.bufferOffset
+ sourceBytesPerRow:copyInfo.bytesPerRow
+ sourceBytesPerImage:copyInfo.bytesPerImage
+ sourceSize:copyInfo.copyExtent
+ toTexture:texture->GetMTLTexture()
+ destinationSlice:dst.arrayLayer
+ destinationLevel:dst.mipLevel
+ destinationOrigin:copyInfo.textureOrigin];
}
} break;
@@ -679,18 +657,17 @@
src.origin, copySize, texture->GetFormat(), virtualSizeAtLevel,
buffer->GetSize(), dst.offset, dst.rowPitch, dst.imageHeight);
- encoders.EnsureBlit(commandBuffer);
for (uint32_t i = 0; i < splittedCopies.count; ++i) {
const TextureBufferCopySplit::CopyInfo& copyInfo = splittedCopies.copies[i];
- [encoders.blit copyFromTexture:texture->GetMTLTexture()
- sourceSlice:src.arrayLayer
- sourceLevel:src.mipLevel
- sourceOrigin:copyInfo.textureOrigin
- sourceSize:copyInfo.copyExtent
- toBuffer:buffer->GetMTLBuffer()
- destinationOffset:copyInfo.bufferOffset
- destinationBytesPerRow:copyInfo.bytesPerRow
- destinationBytesPerImage:copyInfo.bytesPerImage];
+ [commandContext->EnsureBlit() copyFromTexture:texture->GetMTLTexture()
+ sourceSlice:src.arrayLayer
+ sourceLevel:src.mipLevel
+ sourceOrigin:copyInfo.textureOrigin
+ sourceSize:copyInfo.copyExtent
+ toBuffer:buffer->GetMTLBuffer()
+ destinationOffset:copyInfo.bufferOffset
+ destinationBytesPerRow:copyInfo.bytesPerRow
+ destinationBytesPerImage:copyInfo.bytesPerImage];
}
} break;
@@ -700,40 +677,38 @@
Texture* srcTexture = ToBackend(copy->source.texture.Get());
Texture* dstTexture = ToBackend(copy->destination.texture.Get());
- encoders.EnsureBlit(commandBuffer);
-
- [encoders.blit copyFromTexture:srcTexture->GetMTLTexture()
- sourceSlice:copy->source.arrayLayer
- sourceLevel:copy->source.mipLevel
- sourceOrigin:MakeMTLOrigin(copy->source.origin)
- sourceSize:MakeMTLSize(copy->copySize)
- toTexture:dstTexture->GetMTLTexture()
- destinationSlice:copy->destination.arrayLayer
- destinationLevel:copy->destination.mipLevel
- destinationOrigin:MakeMTLOrigin(copy->destination.origin)];
+ [commandContext->EnsureBlit()
+ copyFromTexture:srcTexture->GetMTLTexture()
+ sourceSlice:copy->source.arrayLayer
+ sourceLevel:copy->source.mipLevel
+ sourceOrigin:MakeMTLOrigin(copy->source.origin)
+ sourceSize:MakeMTLSize(copy->copySize)
+ toTexture:dstTexture->GetMTLTexture()
+ destinationSlice:copy->destination.arrayLayer
+ destinationLevel:copy->destination.mipLevel
+ destinationOrigin:MakeMTLOrigin(copy->destination.origin)];
} break;
default: { UNREACHABLE(); } break;
}
}
- encoders.Finish();
+ commandContext->EndBlit();
}
- void CommandBuffer::EncodeComputePass(id<MTLCommandBuffer> commandBuffer) {
+ void CommandBuffer::EncodeComputePass(CommandRecordingContext* commandContext) {
ComputePipeline* lastPipeline = nullptr;
StorageBufferLengthTracker storageBufferLengths = {};
BindGroupTracker bindGroups(&storageBufferLengths);
- // Will be autoreleased
- id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+ id<MTLComputeCommandEncoder> encoder = commandContext->BeginCompute();
Command type;
while (mCommands.NextCommandId(&type)) {
switch (type) {
case Command::EndComputePass: {
mCommands.NextCommand<EndComputePassCmd>();
- [encoder endEncoding];
+ commandContext->EndCompute();
return;
} break;
@@ -813,12 +788,11 @@
UNREACHABLE();
}
- void CommandBuffer::EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+ void CommandBuffer::EncodeRenderPass(CommandRecordingContext* commandContext,
MTLRenderPassDescriptor* mtlRenderPass,
- GlobalEncoders* globalEncoders,
uint32_t width,
uint32_t height) {
- ASSERT(mtlRenderPass && globalEncoders);
+ ASSERT(mtlRenderPass);
Device* device = ToBackend(GetDevice());
@@ -861,17 +835,16 @@
// If we need to use a temporary resolve texture we need to copy the result of MSAA
// resolve back to the true resolve targets.
if (useTemporaryResolveTexture) {
- EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
+ EncodeRenderPass(commandContext, mtlRenderPass, width, height);
for (uint32_t i = 0; i < kMaxColorAttachments; ++i) {
if (trueResolveTextures[i] == nil) {
continue;
}
ASSERT(temporaryResolveTextures[i] != nil);
- CopyIntoTrueResolveTarget(commandBuffer, trueResolveTextures[i],
+ CopyIntoTrueResolveTarget(commandContext, trueResolveTextures[i],
trueResolveLevels[i], trueResolveSlices[i],
- temporaryResolveTextures[i], width, height,
- globalEncoders);
+ temporaryResolveTextures[i], width, height);
}
return;
}
@@ -896,16 +869,16 @@
// If we found a store + MSAA resolve we need to resolve in a different render pass.
if (hasStoreAndMSAAResolve) {
- EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
- ResolveInAnotherRenderPass(commandBuffer, mtlRenderPass, resolveTextures);
+ EncodeRenderPass(commandContext, mtlRenderPass, width, height);
+ ResolveInAnotherRenderPass(commandContext, mtlRenderPass, resolveTextures);
return;
}
}
- EncodeRenderPassInternal(commandBuffer, mtlRenderPass, width, height);
+ EncodeRenderPassInternal(commandContext, mtlRenderPass, width, height);
}
- void CommandBuffer::EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+ void CommandBuffer::EncodeRenderPassInternal(CommandRecordingContext* commandContext,
MTLRenderPassDescriptor* mtlRenderPass,
uint32_t width,
uint32_t height) {
@@ -916,9 +889,7 @@
StorageBufferLengthTracker storageBufferLengths = {};
BindGroupTracker bindGroups(&storageBufferLengths);
- // This will be autoreleased
- id<MTLRenderCommandEncoder> encoder =
- [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPass];
+ id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass);
auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
switch (type) {
@@ -1068,7 +1039,7 @@
switch (type) {
case Command::EndRenderPass: {
mCommands.NextCommand<EndRenderPassCmd>();
- [encoder endEncoding];
+ commandContext->EndRender();
return;
} break;
diff --git a/src/dawn_native/metal/CommandRecordingContext.h b/src/dawn_native/metal/CommandRecordingContext.h
new file mode 100644
index 0000000..531681b
--- /dev/null
+++ b/src/dawn_native/metal/CommandRecordingContext.h
@@ -0,0 +1,59 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
+#define DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
+
+#import <Metal/Metal.h>
+
+namespace dawn_native { namespace metal {
+
+ // This class wraps a MTLCommandBuffer and tracks which Metal encoder is open.
+ // Only one encoder may be open at a time.
+ class CommandRecordingContext {
+ public:
+ CommandRecordingContext();
+ CommandRecordingContext(id<MTLCommandBuffer> commands);
+
+ CommandRecordingContext(const CommandRecordingContext& rhs) = delete;
+ CommandRecordingContext& operator=(const CommandRecordingContext& rhs) = delete;
+
+ CommandRecordingContext(CommandRecordingContext&& rhs);
+ CommandRecordingContext& operator=(CommandRecordingContext&& rhs);
+
+ ~CommandRecordingContext();
+
+ id<MTLCommandBuffer> GetCommands();
+
+ id<MTLCommandBuffer> AcquireCommands();
+
+ id<MTLBlitCommandEncoder> EnsureBlit();
+ void EndBlit();
+
+ id<MTLComputeCommandEncoder> BeginCompute();
+ void EndCompute();
+
+ id<MTLRenderCommandEncoder> BeginRender(MTLRenderPassDescriptor* descriptor);
+ void EndRender();
+
+ private:
+ id<MTLCommandBuffer> mCommands = nil;
+ id<MTLBlitCommandEncoder> mBlit = nil;
+ id<MTLComputeCommandEncoder> mCompute = nil;
+ id<MTLRenderCommandEncoder> mRender = nil;
+ bool mInEncoder = false;
+ };
+
+}} // namespace dawn_native::metal
+
+#endif // DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
diff --git a/src/dawn_native/metal/CommandRecordingContext.mm b/src/dawn_native/metal/CommandRecordingContext.mm
new file mode 100644
index 0000000..df4d6f8
--- /dev/null
+++ b/src/dawn_native/metal/CommandRecordingContext.mm
@@ -0,0 +1,113 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/metal/CommandRecordingContext.h"
+
+#include "common/Assert.h"
+
+namespace dawn_native { namespace metal {
+
+ CommandRecordingContext::CommandRecordingContext() = default;
+
+ CommandRecordingContext::CommandRecordingContext(id<MTLCommandBuffer> commands)
+ : mCommands(commands) {
+ }
+
+ CommandRecordingContext::CommandRecordingContext(CommandRecordingContext&& rhs)
+ : mCommands(rhs.AcquireCommands()) {
+ }
+
+ CommandRecordingContext& CommandRecordingContext::operator=(CommandRecordingContext&& rhs) {
+ mCommands = rhs.AcquireCommands();
+ return *this;
+ }
+
+ CommandRecordingContext::~CommandRecordingContext() {
+ // Commands must be acquired.
+ ASSERT(mCommands == nil);
+ }
+
+ id<MTLCommandBuffer> CommandRecordingContext::GetCommands() {
+ return mCommands;
+ }
+
+ id<MTLCommandBuffer> CommandRecordingContext::AcquireCommands() {
+ ASSERT(!mInEncoder);
+
+ id<MTLCommandBuffer> commands = mCommands;
+ mCommands = nil;
+ return commands;
+ }
+
+ id<MTLBlitCommandEncoder> CommandRecordingContext::EnsureBlit() {
+ ASSERT(mCommands != nil);
+
+ if (mBlit == nil) {
+ ASSERT(!mInEncoder);
+ mInEncoder = true;
+ mBlit = [mCommands blitCommandEncoder];
+ }
+ return mBlit;
+ }
+
+ void CommandRecordingContext::EndBlit() {
+ ASSERT(mCommands != nil);
+
+ if (mBlit != nil) {
+ [mBlit endEncoding];
+ mBlit = nil; // This will be autoreleased.
+ mInEncoder = false;
+ }
+ }
+
+ id<MTLComputeCommandEncoder> CommandRecordingContext::BeginCompute() {
+ ASSERT(mCommands != nil);
+ ASSERT(mCompute == nil);
+ ASSERT(!mInEncoder);
+
+ mInEncoder = true;
+ mCompute = [mCommands computeCommandEncoder];
+ return mCompute;
+ }
+
+ void CommandRecordingContext::EndCompute() {
+ ASSERT(mCommands != nil);
+ ASSERT(mCompute != nil);
+
+ [mCompute endEncoding];
+ mCompute = nil; // This will be autoreleased.
+ mInEncoder = false;
+ }
+
+ id<MTLRenderCommandEncoder> CommandRecordingContext::BeginRender(
+ MTLRenderPassDescriptor* descriptor) {
+ ASSERT(mCommands != nil);
+ ASSERT(mRender == nil);
+ ASSERT(!mInEncoder);
+
+ mInEncoder = true;
+ mRender = [mCommands renderCommandEncoderWithDescriptor:descriptor];
+ return mRender;
+ }
+
+ void CommandRecordingContext::EndRender() {
+ ASSERT(mCommands != nil);
+ ASSERT(mRender != nil);
+
+ [mRender endEncoding];
+ mRender = nil; // This will be autoreleased.
+ mInEncoder = false;
+ }
+
+}} // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index 2219ab6..667eb56 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -19,6 +19,7 @@
#include "common/Serial.h"
#include "dawn_native/Device.h"
+#include "dawn_native/metal/CommandRecordingContext.h"
#include "dawn_native/metal/Forward.h"
#import <IOSurface/IOSurfaceRef.h>
@@ -48,7 +49,7 @@
id<MTLDevice> GetMTLDevice();
id<MTLCommandQueue> GetMTLQueue();
- id<MTLCommandBuffer> GetPendingCommandBuffer();
+ CommandRecordingContext* GetPendingCommandContext();
Serial GetPendingCommandSerial() const override;
void SubmitPendingCommandBuffer();
@@ -98,7 +99,7 @@
std::unique_ptr<MapRequestTracker> mMapTracker;
Serial mLastSubmittedSerial = 0;
- id<MTLCommandBuffer> mPendingCommands = nil;
+ CommandRecordingContext mCommandContext;
// The completed serial is updated in a Metal completion handler that can be fired on a
// different thread, so it needs to be atomic.
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 6cd73a9..54b1a19 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -144,7 +144,7 @@
mDynamicUploader->Deallocate(completedSerial);
mMapTracker->Tick(completedSerial);
- if (mPendingCommands != nil) {
+ if (mCommandContext.GetCommands() != nil) {
SubmitPendingCommandBuffer();
} else if (completedSerial == mLastSubmittedSerial) {
// If there's no GPU work in flight we still need to artificially increment the serial
@@ -164,45 +164,43 @@
return mCommandQueue;
}
- id<MTLCommandBuffer> Device::GetPendingCommandBuffer() {
- TRACE_EVENT0(GetPlatform(), General, "DeviceMTL::GetPendingCommandBuffer");
- if (mPendingCommands == nil) {
- mPendingCommands = [mCommandQueue commandBuffer];
- [mPendingCommands retain];
+ CommandRecordingContext* Device::GetPendingCommandContext() {
+ if (mCommandContext.GetCommands() == nil) {
+ TRACE_EVENT0(GetPlatform(), General, "[MTLCommandQueue commandBuffer]");
+ mCommandContext = CommandRecordingContext([mCommandQueue commandBuffer]);
}
- return mPendingCommands;
+ return &mCommandContext;
}
void Device::SubmitPendingCommandBuffer() {
- if (mPendingCommands == nil) {
+ if (mCommandContext.GetCommands() == nil) {
return;
}
mLastSubmittedSerial++;
+ // Ensure the blit encoder is ended. It may have been opened to perform a lazy clear or
+ // buffer upload.
+ mCommandContext.EndBlit();
+
+ // Acquire and retain the pending commands. We must keep them alive until scheduled.
+ id<MTLCommandBuffer> pendingCommands = [mCommandContext.AcquireCommands() retain];
+
// Replace mLastSubmittedCommands with the mutex held so we avoid races between the
// schedule handler and this code.
{
std::lock_guard<std::mutex> lock(mLastSubmittedCommandsMutex);
[mLastSubmittedCommands release];
- mLastSubmittedCommands = mPendingCommands;
+ mLastSubmittedCommands = pendingCommands;
}
- // Ok, ObjC blocks are weird. My understanding is that local variables are captured by
- // value so this-> works as expected. However it is unclear how members are captured, (are
- // they captured using this-> or by value?). To be safe we copy members to local variables
- // to ensure they are captured "by value".
-
- // Free mLastSubmittedCommands as soon as it is scheduled so that it doesn't hold
- // references to its resources. Make a local copy of pendingCommands first so it is
- // captured "by-value" by the block.
- id<MTLCommandBuffer> pendingCommands = mPendingCommands;
-
- [mPendingCommands addScheduledHandler:^(id<MTLCommandBuffer>) {
+ [pendingCommands addScheduledHandler:^(id<MTLCommandBuffer>) {
// This is DRF because we hold the mutex for mLastSubmittedCommands and pendingCommands
// is a local value (and not the member itself).
std::lock_guard<std::mutex> lock(mLastSubmittedCommandsMutex);
if (this->mLastSubmittedCommands == pendingCommands) {
+ // Free mLastSubmittedCommands as soon as it is scheduled so that it doesn't hold
+ // references to its resources.
[this->mLastSubmittedCommands release];
this->mLastSubmittedCommands = nil;
}
@@ -211,7 +209,7 @@
// Update the completed serial once the completed handler is fired. Make a local copy of
// mLastSubmittedSerial so it is captured by value.
Serial pendingSerial = mLastSubmittedSerial;
- [mPendingCommands addCompletedHandler:^(id<MTLCommandBuffer>) {
+ [pendingCommands addCompletedHandler:^(id<MTLCommandBuffer>) {
TRACE_EVENT_ASYNC_END0(GetPlatform(), GPUWork, "DeviceMTL::SubmitPendingCommandBuffer",
pendingSerial);
ASSERT(pendingSerial > mCompletedSerial.load());
@@ -220,8 +218,7 @@
TRACE_EVENT_ASYNC_BEGIN0(GetPlatform(), GPUWork, "DeviceMTL::SubmitPendingCommandBuffer",
pendingSerial);
- [mPendingCommands commit];
- mPendingCommands = nil;
+ [pendingCommands commit];
}
MapRequestTracker* Device::GetMapTracker() const {
@@ -242,15 +239,11 @@
uint64_t size) {
id<MTLBuffer> uploadBuffer = ToBackend(source)->GetBufferHandle();
id<MTLBuffer> buffer = ToBackend(destination)->GetMTLBuffer();
- id<MTLCommandBuffer> commandBuffer = GetPendingCommandBuffer();
- id<MTLBlitCommandEncoder> encoder = [commandBuffer blitCommandEncoder];
- [encoder copyFromBuffer:uploadBuffer
- sourceOffset:sourceOffset
- toBuffer:buffer
- destinationOffset:destinationOffset
- size:size];
- [encoder endEncoding];
-
+ [GetPendingCommandContext()->EnsureBlit() copyFromBuffer:uploadBuffer
+ sourceOffset:sourceOffset
+ toBuffer:buffer
+ destinationOffset:destinationOffset
+ size:size];
return {};
}
@@ -273,8 +266,7 @@
}
MaybeError Device::WaitForIdleForDestruction() {
- [mPendingCommands release];
- mPendingCommands = nil;
+ [mCommandContext.AcquireCommands() release];
// Wait for all commands to be finished so we can free resources
while (GetCompletedCommandSerial() != mLastSubmittedSerial) {
@@ -285,10 +277,7 @@
}
void Device::Destroy() {
- if (mPendingCommands != nil) {
- [mPendingCommands release];
- mPendingCommands = nil;
- }
+ [mCommandContext.AcquireCommands() release];
mMapTracker = nullptr;
mDynamicUploader = nullptr;
diff --git a/src/dawn_native/metal/QueueMTL.mm b/src/dawn_native/metal/QueueMTL.mm
index dd360e9..7c5967a 100644
--- a/src/dawn_native/metal/QueueMTL.mm
+++ b/src/dawn_native/metal/QueueMTL.mm
@@ -27,11 +27,11 @@
MaybeError Queue::SubmitImpl(uint32_t commandCount, CommandBufferBase* const* commands) {
Device* device = ToBackend(GetDevice());
device->Tick();
- id<MTLCommandBuffer> commandBuffer = device->GetPendingCommandBuffer();
+ CommandRecordingContext* commandContext = device->GetPendingCommandContext();
TRACE_EVENT_BEGIN0(GetDevice()->GetPlatform(), Recording, "CommandBufferMTL::FillCommands");
for (uint32_t i = 0; i < commandCount; ++i) {
- ToBackend(commands[i])->FillCommands(commandBuffer);
+ ToBackend(commands[i])->FillCommands(commandContext);
}
TRACE_EVENT_END0(GetDevice()->GetPlatform(), Recording, "CommandBufferMTL::FillCommands");