Metal: Add CommandRecordingContext

Introduces the idea of a CommandRecordingContext to the Metal backend,
similar to other backends. This is a class to track which Metal encoder
is open on the device-global pending MTLCommandBuffer.
It will be needed to open/close encoders for lazy clearing.

Bug: dawn:145
Change-Id: Ief6b71a079d73943677d2b61382d1c36b88a4f87
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14780
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 0329c0b..5287bee 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -355,6 +355,8 @@
       "src/dawn_native/metal/BufferMTL.mm",
       "src/dawn_native/metal/CommandBufferMTL.h",
       "src/dawn_native/metal/CommandBufferMTL.mm",
+      "src/dawn_native/metal/CommandRecordingContext.h",
+      "src/dawn_native/metal/CommandRecordingContext.mm",
       "src/dawn_native/metal/ComputePipelineMTL.h",
       "src/dawn_native/metal/ComputePipelineMTL.mm",
       "src/dawn_native/metal/DeviceMTL.h",
diff --git a/src/dawn_native/metal/CommandBufferMTL.h b/src/dawn_native/metal/CommandBufferMTL.h
index 640d196..67a1313 100644
--- a/src/dawn_native/metal/CommandBufferMTL.h
+++ b/src/dawn_native/metal/CommandBufferMTL.h
@@ -26,25 +26,24 @@
 
 namespace dawn_native { namespace metal {
 
+    class CommandRecordingContext;
     class Device;
-    struct GlobalEncoders;
 
     class CommandBuffer : public CommandBufferBase {
       public:
         CommandBuffer(CommandEncoder* encoder, const CommandBufferDescriptor* descriptor);
         ~CommandBuffer();
 
-        void FillCommands(id<MTLCommandBuffer> commandBuffer);
+        void FillCommands(CommandRecordingContext* commandContext);
 
       private:
-        void EncodeComputePass(id<MTLCommandBuffer> commandBuffer);
-        void EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+        void EncodeComputePass(CommandRecordingContext* commandContext);
+        void EncodeRenderPass(CommandRecordingContext* commandContext,
                               MTLRenderPassDescriptor* mtlRenderPass,
-                              GlobalEncoders* globalEncoders,
                               uint32_t width,
                               uint32_t height);
 
-        void EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+        void EncodeRenderPassInternal(CommandRecordingContext* commandContext,
                                       MTLRenderPassDescriptor* mtlRenderPass,
                                       uint32_t width,
                                       uint32_t height);
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 77596dd..866a6fe 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -29,23 +29,6 @@
 
 namespace dawn_native { namespace metal {
 
-    struct GlobalEncoders {
-        id<MTLBlitCommandEncoder> blit = nil;
-
-        void Finish() {
-            if (blit != nil) {
-                [blit endEncoding];
-                blit = nil;  // This will be autoreleased.
-            }
-        }
-
-        void EnsureBlit(id<MTLCommandBuffer> commandBuffer) {
-            if (blit == nil) {
-                blit = [commandBuffer blitCommandEncoder];
-            }
-        }
-    };
-
     namespace {
 
         // Allows this file to use MTLStoreActionStoreAndMultismapleResolve because the logic is
@@ -133,7 +116,7 @@
 
         // Helper function for Toggle EmulateStoreAndMSAAResolve
         void ResolveInAnotherRenderPass(
-            id<MTLCommandBuffer> commandBuffer,
+            CommandRecordingContext* commandContext,
             const MTLRenderPassDescriptor* mtlRenderPass,
             const std::array<id<MTLTexture>, kMaxColorAttachments>& resolveTextures) {
             MTLRenderPassDescriptor* mtlRenderPassForResolve =
@@ -155,9 +138,8 @@
                     mtlRenderPass.colorAttachments[i].resolveSlice;
             }
 
-            id<MTLRenderCommandEncoder> encoder =
-                [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPassForResolve];
-            [encoder endEncoding];
+            commandContext->BeginRender(mtlRenderPassForResolve);
+            commandContext->EndRender();
         }
 
         // Helper functions for Toggle AlwaysResolveIntoZeroLevelAndLayer
@@ -182,24 +164,22 @@
             return resolveTexture;
         }
 
-        void CopyIntoTrueResolveTarget(id<MTLCommandBuffer> commandBuffer,
+        void CopyIntoTrueResolveTarget(CommandRecordingContext* commandContext,
                                        id<MTLTexture> mtlTrueResolveTexture,
                                        uint32_t trueResolveLevel,
                                        uint32_t trueResolveSlice,
                                        id<MTLTexture> temporaryResolveTexture,
                                        uint32_t width,
-                                       uint32_t height,
-                                       GlobalEncoders* encoders) {
-            encoders->EnsureBlit(commandBuffer);
-            [encoders->blit copyFromTexture:temporaryResolveTexture
-                                sourceSlice:0
-                                sourceLevel:0
-                               sourceOrigin:MTLOriginMake(0, 0, 0)
-                                 sourceSize:MTLSizeMake(width, height, 1)
-                                  toTexture:mtlTrueResolveTexture
-                           destinationSlice:trueResolveSlice
-                           destinationLevel:trueResolveLevel
-                          destinationOrigin:MTLOriginMake(0, 0, 0)];
+                                       uint32_t height) {
+            [commandContext->EnsureBlit() copyFromTexture:temporaryResolveTexture
+                                              sourceSlice:0
+                                              sourceLevel:0
+                                             sourceOrigin:MTLOriginMake(0, 0, 0)
+                                               sourceSize:MTLSizeMake(width, height, 1)
+                                                toTexture:mtlTrueResolveTexture
+                                         destinationSlice:trueResolveSlice
+                                         destinationLevel:trueResolveLevel
+                                        destinationOrigin:MTLOriginMake(0, 0, 0)];
         }
 
         // Metal uses a physical addressing mode which means buffers in the shading language are
@@ -608,34 +588,33 @@
         FreeCommands(&mCommands);
     }
 
-    void CommandBuffer::FillCommands(id<MTLCommandBuffer> commandBuffer) {
-        GlobalEncoders encoders;
-
+    void CommandBuffer::FillCommands(CommandRecordingContext* commandContext) {
         Command type;
         while (mCommands.NextCommandId(&type)) {
             switch (type) {
                 case Command::BeginComputePass: {
                     mCommands.NextCommand<BeginComputePassCmd>();
-                    encoders.Finish();
-                    EncodeComputePass(commandBuffer);
+
+                    commandContext->EndBlit();
+                    EncodeComputePass(commandContext);
                 } break;
 
                 case Command::BeginRenderPass: {
                     BeginRenderPassCmd* cmd = mCommands.NextCommand<BeginRenderPassCmd>();
-                    encoders.Finish();
+                    commandContext->EndBlit();
                     MTLRenderPassDescriptor* descriptor = CreateMTLRenderPassDescriptor(cmd);
-                    EncodeRenderPass(commandBuffer, descriptor, &encoders, cmd->width, cmd->height);
+                    EncodeRenderPass(commandContext, descriptor, cmd->width, cmd->height);
                 } break;
 
                 case Command::CopyBufferToBuffer: {
                     CopyBufferToBufferCmd* copy = mCommands.NextCommand<CopyBufferToBufferCmd>();
 
-                    encoders.EnsureBlit(commandBuffer);
-                    [encoders.blit copyFromBuffer:ToBackend(copy->source)->GetMTLBuffer()
-                                     sourceOffset:copy->sourceOffset
-                                         toBuffer:ToBackend(copy->destination)->GetMTLBuffer()
-                                destinationOffset:copy->destinationOffset
-                                             size:copy->size];
+                    [commandContext->EnsureBlit()
+                           copyFromBuffer:ToBackend(copy->source)->GetMTLBuffer()
+                             sourceOffset:copy->sourceOffset
+                                 toBuffer:ToBackend(copy->destination)->GetMTLBuffer()
+                        destinationOffset:copy->destinationOffset
+                                     size:copy->size];
                 } break;
 
                 case Command::CopyBufferToTexture: {
@@ -651,18 +630,17 @@
                         dst.origin, copySize, texture->GetFormat(), virtualSizeAtLevel,
                         buffer->GetSize(), src.offset, src.rowPitch, src.imageHeight);
 
-                    encoders.EnsureBlit(commandBuffer);
                     for (uint32_t i = 0; i < splittedCopies.count; ++i) {
                         const TextureBufferCopySplit::CopyInfo& copyInfo = splittedCopies.copies[i];
-                        [encoders.blit copyFromBuffer:buffer->GetMTLBuffer()
-                                         sourceOffset:copyInfo.bufferOffset
-                                    sourceBytesPerRow:copyInfo.bytesPerRow
-                                  sourceBytesPerImage:copyInfo.bytesPerImage
-                                           sourceSize:copyInfo.copyExtent
-                                            toTexture:texture->GetMTLTexture()
-                                     destinationSlice:dst.arrayLayer
-                                     destinationLevel:dst.mipLevel
-                                    destinationOrigin:copyInfo.textureOrigin];
+                        [commandContext->EnsureBlit() copyFromBuffer:buffer->GetMTLBuffer()
+                                                        sourceOffset:copyInfo.bufferOffset
+                                                   sourceBytesPerRow:copyInfo.bytesPerRow
+                                                 sourceBytesPerImage:copyInfo.bytesPerImage
+                                                          sourceSize:copyInfo.copyExtent
+                                                           toTexture:texture->GetMTLTexture()
+                                                    destinationSlice:dst.arrayLayer
+                                                    destinationLevel:dst.mipLevel
+                                                   destinationOrigin:copyInfo.textureOrigin];
                     }
                 } break;
 
@@ -679,18 +657,17 @@
                         src.origin, copySize, texture->GetFormat(), virtualSizeAtLevel,
                         buffer->GetSize(), dst.offset, dst.rowPitch, dst.imageHeight);
 
-                    encoders.EnsureBlit(commandBuffer);
                     for (uint32_t i = 0; i < splittedCopies.count; ++i) {
                         const TextureBufferCopySplit::CopyInfo& copyInfo = splittedCopies.copies[i];
-                        [encoders.blit copyFromTexture:texture->GetMTLTexture()
-                                           sourceSlice:src.arrayLayer
-                                           sourceLevel:src.mipLevel
-                                          sourceOrigin:copyInfo.textureOrigin
-                                            sourceSize:copyInfo.copyExtent
-                                              toBuffer:buffer->GetMTLBuffer()
-                                     destinationOffset:copyInfo.bufferOffset
-                                destinationBytesPerRow:copyInfo.bytesPerRow
-                              destinationBytesPerImage:copyInfo.bytesPerImage];
+                        [commandContext->EnsureBlit() copyFromTexture:texture->GetMTLTexture()
+                                                          sourceSlice:src.arrayLayer
+                                                          sourceLevel:src.mipLevel
+                                                         sourceOrigin:copyInfo.textureOrigin
+                                                           sourceSize:copyInfo.copyExtent
+                                                             toBuffer:buffer->GetMTLBuffer()
+                                                    destinationOffset:copyInfo.bufferOffset
+                                               destinationBytesPerRow:copyInfo.bytesPerRow
+                                             destinationBytesPerImage:copyInfo.bytesPerImage];
                     }
                 } break;
 
@@ -700,40 +677,38 @@
                     Texture* srcTexture = ToBackend(copy->source.texture.Get());
                     Texture* dstTexture = ToBackend(copy->destination.texture.Get());
 
-                    encoders.EnsureBlit(commandBuffer);
-
-                    [encoders.blit copyFromTexture:srcTexture->GetMTLTexture()
-                                       sourceSlice:copy->source.arrayLayer
-                                       sourceLevel:copy->source.mipLevel
-                                      sourceOrigin:MakeMTLOrigin(copy->source.origin)
-                                        sourceSize:MakeMTLSize(copy->copySize)
-                                         toTexture:dstTexture->GetMTLTexture()
-                                  destinationSlice:copy->destination.arrayLayer
-                                  destinationLevel:copy->destination.mipLevel
-                                 destinationOrigin:MakeMTLOrigin(copy->destination.origin)];
+                    [commandContext->EnsureBlit()
+                          copyFromTexture:srcTexture->GetMTLTexture()
+                              sourceSlice:copy->source.arrayLayer
+                              sourceLevel:copy->source.mipLevel
+                             sourceOrigin:MakeMTLOrigin(copy->source.origin)
+                               sourceSize:MakeMTLSize(copy->copySize)
+                                toTexture:dstTexture->GetMTLTexture()
+                         destinationSlice:copy->destination.arrayLayer
+                         destinationLevel:copy->destination.mipLevel
+                        destinationOrigin:MakeMTLOrigin(copy->destination.origin)];
                 } break;
 
                 default: { UNREACHABLE(); } break;
             }
         }
 
-        encoders.Finish();
+        commandContext->EndBlit();
     }
 
-    void CommandBuffer::EncodeComputePass(id<MTLCommandBuffer> commandBuffer) {
+    void CommandBuffer::EncodeComputePass(CommandRecordingContext* commandContext) {
         ComputePipeline* lastPipeline = nullptr;
         StorageBufferLengthTracker storageBufferLengths = {};
         BindGroupTracker bindGroups(&storageBufferLengths);
 
-        // Will be autoreleased
-        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+        id<MTLComputeCommandEncoder> encoder = commandContext->BeginCompute();
 
         Command type;
         while (mCommands.NextCommandId(&type)) {
             switch (type) {
                 case Command::EndComputePass: {
                     mCommands.NextCommand<EndComputePassCmd>();
-                    [encoder endEncoding];
+                    commandContext->EndCompute();
                     return;
                 } break;
 
@@ -813,12 +788,11 @@
         UNREACHABLE();
     }
 
-    void CommandBuffer::EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+    void CommandBuffer::EncodeRenderPass(CommandRecordingContext* commandContext,
                                          MTLRenderPassDescriptor* mtlRenderPass,
-                                         GlobalEncoders* globalEncoders,
                                          uint32_t width,
                                          uint32_t height) {
-        ASSERT(mtlRenderPass && globalEncoders);
+        ASSERT(mtlRenderPass);
 
         Device* device = ToBackend(GetDevice());
 
@@ -861,17 +835,16 @@
             // If we need to use a temporary resolve texture we need to copy the result of MSAA
             // resolve back to the true resolve targets.
             if (useTemporaryResolveTexture) {
-                EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
+                EncodeRenderPass(commandContext, mtlRenderPass, width, height);
                 for (uint32_t i = 0; i < kMaxColorAttachments; ++i) {
                     if (trueResolveTextures[i] == nil) {
                         continue;
                     }
 
                     ASSERT(temporaryResolveTextures[i] != nil);
-                    CopyIntoTrueResolveTarget(commandBuffer, trueResolveTextures[i],
+                    CopyIntoTrueResolveTarget(commandContext, trueResolveTextures[i],
                                               trueResolveLevels[i], trueResolveSlices[i],
-                                              temporaryResolveTextures[i], width, height,
-                                              globalEncoders);
+                                              temporaryResolveTextures[i], width, height);
                 }
                 return;
             }
@@ -896,16 +869,16 @@
 
             // If we found a store + MSAA resolve we need to resolve in a different render pass.
             if (hasStoreAndMSAAResolve) {
-                EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
-                ResolveInAnotherRenderPass(commandBuffer, mtlRenderPass, resolveTextures);
+                EncodeRenderPass(commandContext, mtlRenderPass, width, height);
+                ResolveInAnotherRenderPass(commandContext, mtlRenderPass, resolveTextures);
                 return;
             }
         }
 
-        EncodeRenderPassInternal(commandBuffer, mtlRenderPass, width, height);
+        EncodeRenderPassInternal(commandContext, mtlRenderPass, width, height);
     }
 
-    void CommandBuffer::EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+    void CommandBuffer::EncodeRenderPassInternal(CommandRecordingContext* commandContext,
                                                  MTLRenderPassDescriptor* mtlRenderPass,
                                                  uint32_t width,
                                                  uint32_t height) {
@@ -916,9 +889,7 @@
         StorageBufferLengthTracker storageBufferLengths = {};
         BindGroupTracker bindGroups(&storageBufferLengths);
 
-        // This will be autoreleased
-        id<MTLRenderCommandEncoder> encoder =
-            [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPass];
+        id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass);
 
         auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
             switch (type) {
@@ -1068,7 +1039,7 @@
             switch (type) {
                 case Command::EndRenderPass: {
                     mCommands.NextCommand<EndRenderPassCmd>();
-                    [encoder endEncoding];
+                    commandContext->EndRender();
                     return;
                 } break;
 
diff --git a/src/dawn_native/metal/CommandRecordingContext.h b/src/dawn_native/metal/CommandRecordingContext.h
new file mode 100644
index 0000000..531681b
--- /dev/null
+++ b/src/dawn_native/metal/CommandRecordingContext.h
@@ -0,0 +1,59 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
+#define DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
+
+#import <Metal/Metal.h>
+
+namespace dawn_native { namespace metal {
+
+    // This class wraps a MTLCommandBuffer and tracks which Metal encoder is open.
+    // Only one encoder may be open at a time.
+    class CommandRecordingContext {
+      public:
+        CommandRecordingContext();
+        CommandRecordingContext(id<MTLCommandBuffer> commands);
+
+        CommandRecordingContext(const CommandRecordingContext& rhs) = delete;
+        CommandRecordingContext& operator=(const CommandRecordingContext& rhs) = delete;
+
+        CommandRecordingContext(CommandRecordingContext&& rhs);
+        CommandRecordingContext& operator=(CommandRecordingContext&& rhs);
+
+        ~CommandRecordingContext();
+
+        id<MTLCommandBuffer> GetCommands();
+
+        id<MTLCommandBuffer> AcquireCommands();
+
+        id<MTLBlitCommandEncoder> EnsureBlit();
+        void EndBlit();
+
+        id<MTLComputeCommandEncoder> BeginCompute();
+        void EndCompute();
+
+        id<MTLRenderCommandEncoder> BeginRender(MTLRenderPassDescriptor* descriptor);
+        void EndRender();
+
+      private:
+        id<MTLCommandBuffer> mCommands = nil;
+        id<MTLBlitCommandEncoder> mBlit = nil;
+        id<MTLComputeCommandEncoder> mCompute = nil;
+        id<MTLRenderCommandEncoder> mRender = nil;
+        bool mInEncoder = false;
+    };
+
+}}  // namespace dawn_native::metal
+
+#endif  // DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
diff --git a/src/dawn_native/metal/CommandRecordingContext.mm b/src/dawn_native/metal/CommandRecordingContext.mm
new file mode 100644
index 0000000..df4d6f8
--- /dev/null
+++ b/src/dawn_native/metal/CommandRecordingContext.mm
@@ -0,0 +1,113 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/metal/CommandRecordingContext.h"
+
+#include "common/Assert.h"
+
+namespace dawn_native { namespace metal {
+
+    CommandRecordingContext::CommandRecordingContext() = default;
+
+    CommandRecordingContext::CommandRecordingContext(id<MTLCommandBuffer> commands)
+        : mCommands(commands) {
+    }
+
+    CommandRecordingContext::CommandRecordingContext(CommandRecordingContext&& rhs)
+        : mCommands(rhs.AcquireCommands()) {
+    }
+
+    CommandRecordingContext& CommandRecordingContext::operator=(CommandRecordingContext&& rhs) {
+        mCommands = rhs.AcquireCommands();
+        return *this;
+    }
+
+    CommandRecordingContext::~CommandRecordingContext() {
+        // Commands must be acquired.
+        ASSERT(mCommands == nil);
+    }
+
+    id<MTLCommandBuffer> CommandRecordingContext::GetCommands() {
+        return mCommands;
+    }
+
+    id<MTLCommandBuffer> CommandRecordingContext::AcquireCommands() {
+        ASSERT(!mInEncoder);
+
+        id<MTLCommandBuffer> commands = mCommands;
+        mCommands = nil;
+        return commands;
+    }
+
+    id<MTLBlitCommandEncoder> CommandRecordingContext::EnsureBlit() {
+        ASSERT(mCommands != nil);
+
+        if (mBlit == nil) {
+            ASSERT(!mInEncoder);
+            mInEncoder = true;
+            mBlit = [mCommands blitCommandEncoder];
+        }
+        return mBlit;
+    }
+
+    void CommandRecordingContext::EndBlit() {
+        ASSERT(mCommands != nil);
+
+        if (mBlit != nil) {
+            [mBlit endEncoding];
+            mBlit = nil;  // This will be autoreleased.
+            mInEncoder = false;
+        }
+    }
+
+    id<MTLComputeCommandEncoder> CommandRecordingContext::BeginCompute() {
+        ASSERT(mCommands != nil);
+        ASSERT(mCompute == nil);
+        ASSERT(!mInEncoder);
+
+        mInEncoder = true;
+        mCompute = [mCommands computeCommandEncoder];
+        return mCompute;
+    }
+
+    void CommandRecordingContext::EndCompute() {
+        ASSERT(mCommands != nil);
+        ASSERT(mCompute != nil);
+
+        [mCompute endEncoding];
+        mCompute = nil;  // This will be autoreleased.
+        mInEncoder = false;
+    }
+
+    id<MTLRenderCommandEncoder> CommandRecordingContext::BeginRender(
+        MTLRenderPassDescriptor* descriptor) {
+        ASSERT(mCommands != nil);
+        ASSERT(mRender == nil);
+        ASSERT(!mInEncoder);
+
+        mInEncoder = true;
+        mRender = [mCommands renderCommandEncoderWithDescriptor:descriptor];
+        return mRender;
+    }
+
+    void CommandRecordingContext::EndRender() {
+        ASSERT(mCommands != nil);
+        ASSERT(mRender != nil);
+
+        [mRender endEncoding];
+        mRender = nil;  // This will be autoreleased.
+        mInEncoder = false;
+    }
+
+}}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index 2219ab6..667eb56 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -19,6 +19,7 @@
 
 #include "common/Serial.h"
 #include "dawn_native/Device.h"
+#include "dawn_native/metal/CommandRecordingContext.h"
 #include "dawn_native/metal/Forward.h"
 
 #import <IOSurface/IOSurfaceRef.h>
@@ -48,7 +49,7 @@
         id<MTLDevice> GetMTLDevice();
         id<MTLCommandQueue> GetMTLQueue();
 
-        id<MTLCommandBuffer> GetPendingCommandBuffer();
+        CommandRecordingContext* GetPendingCommandContext();
         Serial GetPendingCommandSerial() const override;
         void SubmitPendingCommandBuffer();
 
@@ -98,7 +99,7 @@
         std::unique_ptr<MapRequestTracker> mMapTracker;
 
         Serial mLastSubmittedSerial = 0;
-        id<MTLCommandBuffer> mPendingCommands = nil;
+        CommandRecordingContext mCommandContext;
 
         // The completed serial is updated in a Metal completion handler that can be fired on a
         // different thread, so it needs to be atomic.
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 6cd73a9..54b1a19 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -144,7 +144,7 @@
         mDynamicUploader->Deallocate(completedSerial);
         mMapTracker->Tick(completedSerial);
 
-        if (mPendingCommands != nil) {
+        if (mCommandContext.GetCommands() != nil) {
             SubmitPendingCommandBuffer();
         } else if (completedSerial == mLastSubmittedSerial) {
             // If there's no GPU work in flight we still need to artificially increment the serial
@@ -164,45 +164,43 @@
         return mCommandQueue;
     }
 
-    id<MTLCommandBuffer> Device::GetPendingCommandBuffer() {
-        TRACE_EVENT0(GetPlatform(), General, "DeviceMTL::GetPendingCommandBuffer");
-        if (mPendingCommands == nil) {
-            mPendingCommands = [mCommandQueue commandBuffer];
-            [mPendingCommands retain];
+    CommandRecordingContext* Device::GetPendingCommandContext() {
+        if (mCommandContext.GetCommands() == nil) {
+            TRACE_EVENT0(GetPlatform(), General, "[MTLCommandQueue commandBuffer]");
+            mCommandContext = CommandRecordingContext([mCommandQueue commandBuffer]);
         }
-        return mPendingCommands;
+        return &mCommandContext;
     }
 
     void Device::SubmitPendingCommandBuffer() {
-        if (mPendingCommands == nil) {
+        if (mCommandContext.GetCommands() == nil) {
             return;
         }
 
         mLastSubmittedSerial++;
 
+        // Ensure the blit encoder is ended. It may have been opened to perform a lazy clear or
+        // buffer upload.
+        mCommandContext.EndBlit();
+
+        // Acquire and retain the pending commands. We must keep them alive until scheduled.
+        id<MTLCommandBuffer> pendingCommands = [mCommandContext.AcquireCommands() retain];
+
         // Replace mLastSubmittedCommands with the mutex held so we avoid races between the
         // schedule handler and this code.
         {
             std::lock_guard<std::mutex> lock(mLastSubmittedCommandsMutex);
             [mLastSubmittedCommands release];
-            mLastSubmittedCommands = mPendingCommands;
+            mLastSubmittedCommands = pendingCommands;
         }
 
-        // Ok, ObjC blocks are weird. My understanding is that local variables are captured by
-        // value so this-> works as expected. However it is unclear how members are captured, (are
-        // they captured using this-> or by value?). To be safe we copy members to local variables
-        // to ensure they are captured "by value".
-
-        // Free mLastSubmittedCommands as soon as it is scheduled so that it doesn't hold
-        // references to its resources. Make a local copy of pendingCommands first so it is
-        // captured "by-value" by the block.
-        id<MTLCommandBuffer> pendingCommands = mPendingCommands;
-
-        [mPendingCommands addScheduledHandler:^(id<MTLCommandBuffer>) {
+        [pendingCommands addScheduledHandler:^(id<MTLCommandBuffer>) {
             // This is DRF because we hold the mutex for mLastSubmittedCommands and pendingCommands
             // is a local value (and not the member itself).
             std::lock_guard<std::mutex> lock(mLastSubmittedCommandsMutex);
             if (this->mLastSubmittedCommands == pendingCommands) {
+                // Free mLastSubmittedCommands as soon as it is scheduled so that it doesn't hold
+                // references to its resources.
                 [this->mLastSubmittedCommands release];
                 this->mLastSubmittedCommands = nil;
             }
@@ -211,7 +209,7 @@
         // Update the completed serial once the completed handler is fired. Make a local copy of
         // mLastSubmittedSerial so it is captured by value.
         Serial pendingSerial = mLastSubmittedSerial;
-        [mPendingCommands addCompletedHandler:^(id<MTLCommandBuffer>) {
+        [pendingCommands addCompletedHandler:^(id<MTLCommandBuffer>) {
             TRACE_EVENT_ASYNC_END0(GetPlatform(), GPUWork, "DeviceMTL::SubmitPendingCommandBuffer",
                                    pendingSerial);
             ASSERT(pendingSerial > mCompletedSerial.load());
@@ -220,8 +218,7 @@
 
         TRACE_EVENT_ASYNC_BEGIN0(GetPlatform(), GPUWork, "DeviceMTL::SubmitPendingCommandBuffer",
                                  pendingSerial);
-        [mPendingCommands commit];
-        mPendingCommands = nil;
+        [pendingCommands commit];
     }
 
     MapRequestTracker* Device::GetMapTracker() const {
@@ -242,15 +239,11 @@
                                                uint64_t size) {
         id<MTLBuffer> uploadBuffer = ToBackend(source)->GetBufferHandle();
         id<MTLBuffer> buffer = ToBackend(destination)->GetMTLBuffer();
-        id<MTLCommandBuffer> commandBuffer = GetPendingCommandBuffer();
-        id<MTLBlitCommandEncoder> encoder = [commandBuffer blitCommandEncoder];
-        [encoder copyFromBuffer:uploadBuffer
-                   sourceOffset:sourceOffset
-                       toBuffer:buffer
-              destinationOffset:destinationOffset
-                           size:size];
-        [encoder endEncoding];
-
+        [GetPendingCommandContext()->EnsureBlit() copyFromBuffer:uploadBuffer
+                                                    sourceOffset:sourceOffset
+                                                        toBuffer:buffer
+                                               destinationOffset:destinationOffset
+                                                            size:size];
         return {};
     }
 
@@ -273,8 +266,7 @@
     }
 
     MaybeError Device::WaitForIdleForDestruction() {
-        [mPendingCommands release];
-        mPendingCommands = nil;
+        [mCommandContext.AcquireCommands() release];
 
         // Wait for all commands to be finished so we can free resources
         while (GetCompletedCommandSerial() != mLastSubmittedSerial) {
@@ -285,10 +277,7 @@
     }
 
     void Device::Destroy() {
-        if (mPendingCommands != nil) {
-            [mPendingCommands release];
-            mPendingCommands = nil;
-        }
+        [mCommandContext.AcquireCommands() release];
 
         mMapTracker = nullptr;
         mDynamicUploader = nullptr;
diff --git a/src/dawn_native/metal/QueueMTL.mm b/src/dawn_native/metal/QueueMTL.mm
index dd360e9..7c5967a 100644
--- a/src/dawn_native/metal/QueueMTL.mm
+++ b/src/dawn_native/metal/QueueMTL.mm
@@ -27,11 +27,11 @@
     MaybeError Queue::SubmitImpl(uint32_t commandCount, CommandBufferBase* const* commands) {
         Device* device = ToBackend(GetDevice());
         device->Tick();
-        id<MTLCommandBuffer> commandBuffer = device->GetPendingCommandBuffer();
+        CommandRecordingContext* commandContext = device->GetPendingCommandContext();
 
         TRACE_EVENT_BEGIN0(GetDevice()->GetPlatform(), Recording, "CommandBufferMTL::FillCommands");
         for (uint32_t i = 0; i < commandCount; ++i) {
-            ToBackend(commands[i])->FillCommands(commandBuffer);
+            ToBackend(commands[i])->FillCommands(commandContext);
         }
         TRACE_EVENT_END0(GetDevice()->GetPlatform(), Recording, "CommandBufferMTL::FillCommands");