Reland "Metal: Add CommandRecordingContext"

This is a reland of 2b3975f808fd6f5afc5a52e58a3dcd5e73984b17

The previous CL failed to retain autoreleased ObjC objects which
should live longer than the autoreleasepool block. This reland fixes
the issue and adds tests for it.

Original change's description:
> Metal: Add CommandRecordingContext
>
> Introduces the idea of a CommandRecordingContext to the Metal backend,
> similar to other backends. This is a class to track which Metal encoder
> is open on the device-global pending MTLCommandBuffer.
> It will be needed to open/close encoders for lazy clearing.
>
> Bug: dawn:145
> Change-Id: Ief6b71a079d73943677d2b61382d1c36b88a4f87
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/14780
> Reviewed-by: Corentin Wallez <cwallez@chromium.org>
> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
> Commit-Queue: Austin Eng <enga@chromium.org>

Bug: dawn:145
Change-Id: I67494b35225ce8f6443a3fa9787d054522e5d422
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/15042
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn
index 866b2d4..b6201ce 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -358,6 +358,8 @@
       "src/dawn_native/metal/BufferMTL.mm",
       "src/dawn_native/metal/CommandBufferMTL.h",
       "src/dawn_native/metal/CommandBufferMTL.mm",
+      "src/dawn_native/metal/CommandRecordingContext.h",
+      "src/dawn_native/metal/CommandRecordingContext.mm",
       "src/dawn_native/metal/ComputePipelineMTL.h",
       "src/dawn_native/metal/ComputePipelineMTL.mm",
       "src/dawn_native/metal/DeviceMTL.h",
@@ -967,7 +969,7 @@
 }
 
 source_set("dawn_white_box_tests_sources") {
-  configs += [ "${dawn_root}/src/common:dawn_internal" ]
+  configs += [ ":libdawn_native_internal" ]
   testonly = true
 
   deps = [
@@ -1002,6 +1004,10 @@
     sources += [ "src/tests/white_box/D3D12SmallTextureTests.cpp" ]
   }
 
+  if (dawn_enable_metal) {
+    sources += [ "src/tests/white_box/MetalAutoreleasePoolTests.mm" ]
+  }
+
   if (dawn_enable_opengl) {
     deps += [ ":dawn_glfw" ]
   }
diff --git a/src/dawn_native/metal/CommandBufferMTL.h b/src/dawn_native/metal/CommandBufferMTL.h
index 640d196..67a1313 100644
--- a/src/dawn_native/metal/CommandBufferMTL.h
+++ b/src/dawn_native/metal/CommandBufferMTL.h
@@ -26,25 +26,24 @@
 
 namespace dawn_native { namespace metal {
 
+    class CommandRecordingContext;
     class Device;
-    struct GlobalEncoders;
 
     class CommandBuffer : public CommandBufferBase {
       public:
         CommandBuffer(CommandEncoder* encoder, const CommandBufferDescriptor* descriptor);
         ~CommandBuffer();
 
-        void FillCommands(id<MTLCommandBuffer> commandBuffer);
+        void FillCommands(CommandRecordingContext* commandContext);
 
       private:
-        void EncodeComputePass(id<MTLCommandBuffer> commandBuffer);
-        void EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+        void EncodeComputePass(CommandRecordingContext* commandContext);
+        void EncodeRenderPass(CommandRecordingContext* commandContext,
                               MTLRenderPassDescriptor* mtlRenderPass,
-                              GlobalEncoders* globalEncoders,
                               uint32_t width,
                               uint32_t height);
 
-        void EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+        void EncodeRenderPassInternal(CommandRecordingContext* commandContext,
                                       MTLRenderPassDescriptor* mtlRenderPass,
                                       uint32_t width,
                                       uint32_t height);
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index 77596dd..866a6fe 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -29,23 +29,6 @@
 
 namespace dawn_native { namespace metal {
 
-    struct GlobalEncoders {
-        id<MTLBlitCommandEncoder> blit = nil;
-
-        void Finish() {
-            if (blit != nil) {
-                [blit endEncoding];
-                blit = nil;  // This will be autoreleased.
-            }
-        }
-
-        void EnsureBlit(id<MTLCommandBuffer> commandBuffer) {
-            if (blit == nil) {
-                blit = [commandBuffer blitCommandEncoder];
-            }
-        }
-    };
-
     namespace {
 
         // Allows this file to use MTLStoreActionStoreAndMultismapleResolve because the logic is
@@ -133,7 +116,7 @@
 
         // Helper function for Toggle EmulateStoreAndMSAAResolve
         void ResolveInAnotherRenderPass(
-            id<MTLCommandBuffer> commandBuffer,
+            CommandRecordingContext* commandContext,
             const MTLRenderPassDescriptor* mtlRenderPass,
             const std::array<id<MTLTexture>, kMaxColorAttachments>& resolveTextures) {
             MTLRenderPassDescriptor* mtlRenderPassForResolve =
@@ -155,9 +138,8 @@
                     mtlRenderPass.colorAttachments[i].resolveSlice;
             }
 
-            id<MTLRenderCommandEncoder> encoder =
-                [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPassForResolve];
-            [encoder endEncoding];
+            commandContext->BeginRender(mtlRenderPassForResolve);
+            commandContext->EndRender();
         }
 
         // Helper functions for Toggle AlwaysResolveIntoZeroLevelAndLayer
@@ -182,24 +164,22 @@
             return resolveTexture;
         }
 
-        void CopyIntoTrueResolveTarget(id<MTLCommandBuffer> commandBuffer,
+        void CopyIntoTrueResolveTarget(CommandRecordingContext* commandContext,
                                        id<MTLTexture> mtlTrueResolveTexture,
                                        uint32_t trueResolveLevel,
                                        uint32_t trueResolveSlice,
                                        id<MTLTexture> temporaryResolveTexture,
                                        uint32_t width,
-                                       uint32_t height,
-                                       GlobalEncoders* encoders) {
-            encoders->EnsureBlit(commandBuffer);
-            [encoders->blit copyFromTexture:temporaryResolveTexture
-                                sourceSlice:0
-                                sourceLevel:0
-                               sourceOrigin:MTLOriginMake(0, 0, 0)
-                                 sourceSize:MTLSizeMake(width, height, 1)
-                                  toTexture:mtlTrueResolveTexture
-                           destinationSlice:trueResolveSlice
-                           destinationLevel:trueResolveLevel
-                          destinationOrigin:MTLOriginMake(0, 0, 0)];
+                                       uint32_t height) {
+            [commandContext->EnsureBlit() copyFromTexture:temporaryResolveTexture
+                                              sourceSlice:0
+                                              sourceLevel:0
+                                             sourceOrigin:MTLOriginMake(0, 0, 0)
+                                               sourceSize:MTLSizeMake(width, height, 1)
+                                                toTexture:mtlTrueResolveTexture
+                                         destinationSlice:trueResolveSlice
+                                         destinationLevel:trueResolveLevel
+                                        destinationOrigin:MTLOriginMake(0, 0, 0)];
         }
 
         // Metal uses a physical addressing mode which means buffers in the shading language are
@@ -608,34 +588,33 @@
         FreeCommands(&mCommands);
     }
 
-    void CommandBuffer::FillCommands(id<MTLCommandBuffer> commandBuffer) {
-        GlobalEncoders encoders;
-
+    void CommandBuffer::FillCommands(CommandRecordingContext* commandContext) {
         Command type;
         while (mCommands.NextCommandId(&type)) {
             switch (type) {
                 case Command::BeginComputePass: {
                     mCommands.NextCommand<BeginComputePassCmd>();
-                    encoders.Finish();
-                    EncodeComputePass(commandBuffer);
+
+                    commandContext->EndBlit();
+                    EncodeComputePass(commandContext);
                 } break;
 
                 case Command::BeginRenderPass: {
                     BeginRenderPassCmd* cmd = mCommands.NextCommand<BeginRenderPassCmd>();
-                    encoders.Finish();
+                    commandContext->EndBlit();
                     MTLRenderPassDescriptor* descriptor = CreateMTLRenderPassDescriptor(cmd);
-                    EncodeRenderPass(commandBuffer, descriptor, &encoders, cmd->width, cmd->height);
+                    EncodeRenderPass(commandContext, descriptor, cmd->width, cmd->height);
                 } break;
 
                 case Command::CopyBufferToBuffer: {
                     CopyBufferToBufferCmd* copy = mCommands.NextCommand<CopyBufferToBufferCmd>();
 
-                    encoders.EnsureBlit(commandBuffer);
-                    [encoders.blit copyFromBuffer:ToBackend(copy->source)->GetMTLBuffer()
-                                     sourceOffset:copy->sourceOffset
-                                         toBuffer:ToBackend(copy->destination)->GetMTLBuffer()
-                                destinationOffset:copy->destinationOffset
-                                             size:copy->size];
+                    [commandContext->EnsureBlit()
+                           copyFromBuffer:ToBackend(copy->source)->GetMTLBuffer()
+                             sourceOffset:copy->sourceOffset
+                                 toBuffer:ToBackend(copy->destination)->GetMTLBuffer()
+                        destinationOffset:copy->destinationOffset
+                                     size:copy->size];
                 } break;
 
                 case Command::CopyBufferToTexture: {
@@ -651,18 +630,17 @@
                         dst.origin, copySize, texture->GetFormat(), virtualSizeAtLevel,
                         buffer->GetSize(), src.offset, src.rowPitch, src.imageHeight);
 
-                    encoders.EnsureBlit(commandBuffer);
                     for (uint32_t i = 0; i < splittedCopies.count; ++i) {
                         const TextureBufferCopySplit::CopyInfo& copyInfo = splittedCopies.copies[i];
-                        [encoders.blit copyFromBuffer:buffer->GetMTLBuffer()
-                                         sourceOffset:copyInfo.bufferOffset
-                                    sourceBytesPerRow:copyInfo.bytesPerRow
-                                  sourceBytesPerImage:copyInfo.bytesPerImage
-                                           sourceSize:copyInfo.copyExtent
-                                            toTexture:texture->GetMTLTexture()
-                                     destinationSlice:dst.arrayLayer
-                                     destinationLevel:dst.mipLevel
-                                    destinationOrigin:copyInfo.textureOrigin];
+                        [commandContext->EnsureBlit() copyFromBuffer:buffer->GetMTLBuffer()
+                                                        sourceOffset:copyInfo.bufferOffset
+                                                   sourceBytesPerRow:copyInfo.bytesPerRow
+                                                 sourceBytesPerImage:copyInfo.bytesPerImage
+                                                          sourceSize:copyInfo.copyExtent
+                                                           toTexture:texture->GetMTLTexture()
+                                                    destinationSlice:dst.arrayLayer
+                                                    destinationLevel:dst.mipLevel
+                                                   destinationOrigin:copyInfo.textureOrigin];
                     }
                 } break;
 
@@ -679,18 +657,17 @@
                         src.origin, copySize, texture->GetFormat(), virtualSizeAtLevel,
                         buffer->GetSize(), dst.offset, dst.rowPitch, dst.imageHeight);
 
-                    encoders.EnsureBlit(commandBuffer);
                     for (uint32_t i = 0; i < splittedCopies.count; ++i) {
                         const TextureBufferCopySplit::CopyInfo& copyInfo = splittedCopies.copies[i];
-                        [encoders.blit copyFromTexture:texture->GetMTLTexture()
-                                           sourceSlice:src.arrayLayer
-                                           sourceLevel:src.mipLevel
-                                          sourceOrigin:copyInfo.textureOrigin
-                                            sourceSize:copyInfo.copyExtent
-                                              toBuffer:buffer->GetMTLBuffer()
-                                     destinationOffset:copyInfo.bufferOffset
-                                destinationBytesPerRow:copyInfo.bytesPerRow
-                              destinationBytesPerImage:copyInfo.bytesPerImage];
+                        [commandContext->EnsureBlit() copyFromTexture:texture->GetMTLTexture()
+                                                          sourceSlice:src.arrayLayer
+                                                          sourceLevel:src.mipLevel
+                                                         sourceOrigin:copyInfo.textureOrigin
+                                                           sourceSize:copyInfo.copyExtent
+                                                             toBuffer:buffer->GetMTLBuffer()
+                                                    destinationOffset:copyInfo.bufferOffset
+                                               destinationBytesPerRow:copyInfo.bytesPerRow
+                                             destinationBytesPerImage:copyInfo.bytesPerImage];
                     }
                 } break;
 
@@ -700,40 +677,38 @@
                     Texture* srcTexture = ToBackend(copy->source.texture.Get());
                     Texture* dstTexture = ToBackend(copy->destination.texture.Get());
 
-                    encoders.EnsureBlit(commandBuffer);
-
-                    [encoders.blit copyFromTexture:srcTexture->GetMTLTexture()
-                                       sourceSlice:copy->source.arrayLayer
-                                       sourceLevel:copy->source.mipLevel
-                                      sourceOrigin:MakeMTLOrigin(copy->source.origin)
-                                        sourceSize:MakeMTLSize(copy->copySize)
-                                         toTexture:dstTexture->GetMTLTexture()
-                                  destinationSlice:copy->destination.arrayLayer
-                                  destinationLevel:copy->destination.mipLevel
-                                 destinationOrigin:MakeMTLOrigin(copy->destination.origin)];
+                    [commandContext->EnsureBlit()
+                          copyFromTexture:srcTexture->GetMTLTexture()
+                              sourceSlice:copy->source.arrayLayer
+                              sourceLevel:copy->source.mipLevel
+                             sourceOrigin:MakeMTLOrigin(copy->source.origin)
+                               sourceSize:MakeMTLSize(copy->copySize)
+                                toTexture:dstTexture->GetMTLTexture()
+                         destinationSlice:copy->destination.arrayLayer
+                         destinationLevel:copy->destination.mipLevel
+                        destinationOrigin:MakeMTLOrigin(copy->destination.origin)];
                 } break;
 
                 default: { UNREACHABLE(); } break;
             }
         }
 
-        encoders.Finish();
+        commandContext->EndBlit();
     }
 
-    void CommandBuffer::EncodeComputePass(id<MTLCommandBuffer> commandBuffer) {
+    void CommandBuffer::EncodeComputePass(CommandRecordingContext* commandContext) {
         ComputePipeline* lastPipeline = nullptr;
         StorageBufferLengthTracker storageBufferLengths = {};
         BindGroupTracker bindGroups(&storageBufferLengths);
 
-        // Will be autoreleased
-        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+        id<MTLComputeCommandEncoder> encoder = commandContext->BeginCompute();
 
         Command type;
         while (mCommands.NextCommandId(&type)) {
             switch (type) {
                 case Command::EndComputePass: {
                     mCommands.NextCommand<EndComputePassCmd>();
-                    [encoder endEncoding];
+                    commandContext->EndCompute();
                     return;
                 } break;
 
@@ -813,12 +788,11 @@
         UNREACHABLE();
     }
 
-    void CommandBuffer::EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+    void CommandBuffer::EncodeRenderPass(CommandRecordingContext* commandContext,
                                          MTLRenderPassDescriptor* mtlRenderPass,
-                                         GlobalEncoders* globalEncoders,
                                          uint32_t width,
                                          uint32_t height) {
-        ASSERT(mtlRenderPass && globalEncoders);
+        ASSERT(mtlRenderPass);
 
         Device* device = ToBackend(GetDevice());
 
@@ -861,17 +835,16 @@
             // If we need to use a temporary resolve texture we need to copy the result of MSAA
             // resolve back to the true resolve targets.
             if (useTemporaryResolveTexture) {
-                EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
+                EncodeRenderPass(commandContext, mtlRenderPass, width, height);
                 for (uint32_t i = 0; i < kMaxColorAttachments; ++i) {
                     if (trueResolveTextures[i] == nil) {
                         continue;
                     }
 
                     ASSERT(temporaryResolveTextures[i] != nil);
-                    CopyIntoTrueResolveTarget(commandBuffer, trueResolveTextures[i],
+                    CopyIntoTrueResolveTarget(commandContext, trueResolveTextures[i],
                                               trueResolveLevels[i], trueResolveSlices[i],
-                                              temporaryResolveTextures[i], width, height,
-                                              globalEncoders);
+                                              temporaryResolveTextures[i], width, height);
                 }
                 return;
             }
@@ -896,16 +869,16 @@
 
             // If we found a store + MSAA resolve we need to resolve in a different render pass.
             if (hasStoreAndMSAAResolve) {
-                EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
-                ResolveInAnotherRenderPass(commandBuffer, mtlRenderPass, resolveTextures);
+                EncodeRenderPass(commandContext, mtlRenderPass, width, height);
+                ResolveInAnotherRenderPass(commandContext, mtlRenderPass, resolveTextures);
                 return;
             }
         }
 
-        EncodeRenderPassInternal(commandBuffer, mtlRenderPass, width, height);
+        EncodeRenderPassInternal(commandContext, mtlRenderPass, width, height);
     }
 
-    void CommandBuffer::EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+    void CommandBuffer::EncodeRenderPassInternal(CommandRecordingContext* commandContext,
                                                  MTLRenderPassDescriptor* mtlRenderPass,
                                                  uint32_t width,
                                                  uint32_t height) {
@@ -916,9 +889,7 @@
         StorageBufferLengthTracker storageBufferLengths = {};
         BindGroupTracker bindGroups(&storageBufferLengths);
 
-        // This will be autoreleased
-        id<MTLRenderCommandEncoder> encoder =
-            [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPass];
+        id<MTLRenderCommandEncoder> encoder = commandContext->BeginRender(mtlRenderPass);
 
         auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
             switch (type) {
@@ -1068,7 +1039,7 @@
             switch (type) {
                 case Command::EndRenderPass: {
                     mCommands.NextCommand<EndRenderPassCmd>();
-                    [encoder endEncoding];
+                    commandContext->EndRender();
                     return;
                 } break;
 
diff --git a/src/dawn_native/metal/CommandRecordingContext.h b/src/dawn_native/metal/CommandRecordingContext.h
new file mode 100644
index 0000000..531681b
--- /dev/null
+++ b/src/dawn_native/metal/CommandRecordingContext.h
@@ -0,0 +1,59 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
+#define DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
+
+#import <Metal/Metal.h>
+
+namespace dawn_native { namespace metal {
+
+    // This class wraps a MTLCommandBuffer and tracks which Metal encoder is open.
+    // Only one encoder may be open at a time.
+    class CommandRecordingContext {
+      public:
+        CommandRecordingContext();
+        CommandRecordingContext(id<MTLCommandBuffer> commands);
+
+        CommandRecordingContext(const CommandRecordingContext& rhs) = delete;
+        CommandRecordingContext& operator=(const CommandRecordingContext& rhs) = delete;
+
+        CommandRecordingContext(CommandRecordingContext&& rhs);
+        CommandRecordingContext& operator=(CommandRecordingContext&& rhs);
+
+        ~CommandRecordingContext();
+
+        id<MTLCommandBuffer> GetCommands();
+
+        id<MTLCommandBuffer> AcquireCommands();
+
+        id<MTLBlitCommandEncoder> EnsureBlit();
+        void EndBlit();
+
+        id<MTLComputeCommandEncoder> BeginCompute();
+        void EndCompute();
+
+        id<MTLRenderCommandEncoder> BeginRender(MTLRenderPassDescriptor* descriptor);
+        void EndRender();
+
+      private:
+        id<MTLCommandBuffer> mCommands = nil;
+        id<MTLBlitCommandEncoder> mBlit = nil;
+        id<MTLComputeCommandEncoder> mCompute = nil;
+        id<MTLRenderCommandEncoder> mRender = nil;
+        bool mInEncoder = false;
+    };
+
+}}  // namespace dawn_native::metal
+
+#endif  // DAWNNATIVE_METAL_COMMANDRECORDINGCONTEXT_H_
diff --git a/src/dawn_native/metal/CommandRecordingContext.mm b/src/dawn_native/metal/CommandRecordingContext.mm
new file mode 100644
index 0000000..3ede626
--- /dev/null
+++ b/src/dawn_native/metal/CommandRecordingContext.mm
@@ -0,0 +1,119 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/metal/CommandRecordingContext.h"
+
+#include "common/Assert.h"
+
+namespace dawn_native { namespace metal {
+
+    CommandRecordingContext::CommandRecordingContext() = default;
+
+    CommandRecordingContext::CommandRecordingContext(id<MTLCommandBuffer> commands)
+        : mCommands(commands) {
+    }
+
+    CommandRecordingContext::CommandRecordingContext(CommandRecordingContext&& rhs)
+        : mCommands(rhs.AcquireCommands()) {
+    }
+
+    CommandRecordingContext& CommandRecordingContext::operator=(CommandRecordingContext&& rhs) {
+        mCommands = rhs.AcquireCommands();
+        return *this;
+    }
+
+    CommandRecordingContext::~CommandRecordingContext() {
+        // Commands must be acquired.
+        ASSERT(mCommands == nil);
+    }
+
+    id<MTLCommandBuffer> CommandRecordingContext::GetCommands() {
+        return mCommands;
+    }
+
+    id<MTLCommandBuffer> CommandRecordingContext::AcquireCommands() {
+        ASSERT(!mInEncoder);
+
+        id<MTLCommandBuffer> commands = mCommands;
+        mCommands = nil;
+        return commands;
+    }
+
+    id<MTLBlitCommandEncoder> CommandRecordingContext::EnsureBlit() {
+        ASSERT(mCommands != nil);
+
+        if (mBlit == nil) {
+            ASSERT(!mInEncoder);
+            mInEncoder = true;
+            // The autorelease pool may drain before the encoder is ended. Retain so it stays alive.
+            mBlit = [[mCommands blitCommandEncoder] retain];
+        }
+        return mBlit;
+    }
+
+    void CommandRecordingContext::EndBlit() {
+        ASSERT(mCommands != nil);
+
+        if (mBlit != nil) {
+            [mBlit endEncoding];
+            [mBlit release];
+            mBlit = nil;
+            mInEncoder = false;
+        }
+    }
+
+    id<MTLComputeCommandEncoder> CommandRecordingContext::BeginCompute() {
+        ASSERT(mCommands != nil);
+        ASSERT(mCompute == nil);
+        ASSERT(!mInEncoder);
+
+        mInEncoder = true;
+        // The autorelease pool may drain before the encoder is ended. Retain so it stays alive.
+        mCompute = [[mCommands computeCommandEncoder] retain];
+        return mCompute;
+    }
+
+    void CommandRecordingContext::EndCompute() {
+        ASSERT(mCommands != nil);
+        ASSERT(mCompute != nil);
+
+        [mCompute endEncoding];
+        [mCompute release];
+        mCompute = nil;
+        mInEncoder = false;
+    }
+
+    id<MTLRenderCommandEncoder> CommandRecordingContext::BeginRender(
+        MTLRenderPassDescriptor* descriptor) {
+        ASSERT(mCommands != nil);
+        ASSERT(mRender == nil);
+        ASSERT(!mInEncoder);
+
+        mInEncoder = true;
+        // The autorelease pool may drain before the encoder is ended. Retain so it stays alive.
+        mRender = [[mCommands renderCommandEncoderWithDescriptor:descriptor] retain];
+        return mRender;
+    }
+
+    void CommandRecordingContext::EndRender() {
+        ASSERT(mCommands != nil);
+        ASSERT(mRender != nil);
+
+        [mRender endEncoding];
+        [mRender release];
+        mRender = nil;
+        mInEncoder = false;
+    }
+
+}}  // namespace dawn_native::metal
diff --git a/src/dawn_native/metal/DeviceMTL.h b/src/dawn_native/metal/DeviceMTL.h
index 2219ab6..667eb56 100644
--- a/src/dawn_native/metal/DeviceMTL.h
+++ b/src/dawn_native/metal/DeviceMTL.h
@@ -19,6 +19,7 @@
 
 #include "common/Serial.h"
 #include "dawn_native/Device.h"
+#include "dawn_native/metal/CommandRecordingContext.h"
 #include "dawn_native/metal/Forward.h"
 
 #import <IOSurface/IOSurfaceRef.h>
@@ -48,7 +49,7 @@
         id<MTLDevice> GetMTLDevice();
         id<MTLCommandQueue> GetMTLQueue();
 
-        id<MTLCommandBuffer> GetPendingCommandBuffer();
+        CommandRecordingContext* GetPendingCommandContext();
         Serial GetPendingCommandSerial() const override;
         void SubmitPendingCommandBuffer();
 
@@ -98,7 +99,7 @@
         std::unique_ptr<MapRequestTracker> mMapTracker;
 
         Serial mLastSubmittedSerial = 0;
-        id<MTLCommandBuffer> mPendingCommands = nil;
+        CommandRecordingContext mCommandContext;
 
         // The completed serial is updated in a Metal completion handler that can be fired on a
         // different thread, so it needs to be atomic.
diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm
index 6cd73a9..77faa4e 100644
--- a/src/dawn_native/metal/DeviceMTL.mm
+++ b/src/dawn_native/metal/DeviceMTL.mm
@@ -144,7 +144,7 @@
         mDynamicUploader->Deallocate(completedSerial);
         mMapTracker->Tick(completedSerial);
 
-        if (mPendingCommands != nil) {
+        if (mCommandContext.GetCommands() != nil) {
             SubmitPendingCommandBuffer();
         } else if (completedSerial == mLastSubmittedSerial) {
             // If there's no GPU work in flight we still need to artificially increment the serial
@@ -164,46 +164,43 @@
         return mCommandQueue;
     }
 
-    id<MTLCommandBuffer> Device::GetPendingCommandBuffer() {
-        TRACE_EVENT0(GetPlatform(), General, "DeviceMTL::GetPendingCommandBuffer");
-        if (mPendingCommands == nil) {
-            mPendingCommands = [mCommandQueue commandBuffer];
-            [mPendingCommands retain];
+    CommandRecordingContext* Device::GetPendingCommandContext() {
+        if (mCommandContext.GetCommands() == nil) {
+            TRACE_EVENT0(GetPlatform(), General, "[MTLCommandQueue commandBuffer]");
+            // The MTLCommandBuffer will be autoreleased by default.
+            // The autorelease pool may drain before the command buffer is submitted. Retain so it
+            // stays alive.
+            mCommandContext = CommandRecordingContext([[mCommandQueue commandBuffer] retain]);
         }
-        return mPendingCommands;
+        return &mCommandContext;
     }
 
     void Device::SubmitPendingCommandBuffer() {
-        if (mPendingCommands == nil) {
+        if (mCommandContext.GetCommands() == nil) {
             return;
         }
 
         mLastSubmittedSerial++;
 
+        // Ensure the blit encoder is ended. It may have been opened to perform a lazy clear or
+        // buffer upload.
+        mCommandContext.EndBlit();
+
+        // Acquire the pending command buffer, which is retained. It must be released later.
+        id<MTLCommandBuffer> pendingCommands = mCommandContext.AcquireCommands();
+
         // Replace mLastSubmittedCommands with the mutex held so we avoid races between the
         // schedule handler and this code.
         {
             std::lock_guard<std::mutex> lock(mLastSubmittedCommandsMutex);
-            [mLastSubmittedCommands release];
-            mLastSubmittedCommands = mPendingCommands;
+            mLastSubmittedCommands = pendingCommands;
         }
 
-        // Ok, ObjC blocks are weird. My understanding is that local variables are captured by
-        // value so this-> works as expected. However it is unclear how members are captured, (are
-        // they captured using this-> or by value?). To be safe we copy members to local variables
-        // to ensure they are captured "by value".
-
-        // Free mLastSubmittedCommands as soon as it is scheduled so that it doesn't hold
-        // references to its resources. Make a local copy of pendingCommands first so it is
-        // captured "by-value" by the block.
-        id<MTLCommandBuffer> pendingCommands = mPendingCommands;
-
-        [mPendingCommands addScheduledHandler:^(id<MTLCommandBuffer>) {
+        [pendingCommands addScheduledHandler:^(id<MTLCommandBuffer>) {
             // This is DRF because we hold the mutex for mLastSubmittedCommands and pendingCommands
             // is a local value (and not the member itself).
             std::lock_guard<std::mutex> lock(mLastSubmittedCommandsMutex);
             if (this->mLastSubmittedCommands == pendingCommands) {
-                [this->mLastSubmittedCommands release];
                 this->mLastSubmittedCommands = nil;
             }
         }];
@@ -211,7 +208,7 @@
         // Update the completed serial once the completed handler is fired. Make a local copy of
         // mLastSubmittedSerial so it is captured by value.
         Serial pendingSerial = mLastSubmittedSerial;
-        [mPendingCommands addCompletedHandler:^(id<MTLCommandBuffer>) {
+        [pendingCommands addCompletedHandler:^(id<MTLCommandBuffer>) {
             TRACE_EVENT_ASYNC_END0(GetPlatform(), GPUWork, "DeviceMTL::SubmitPendingCommandBuffer",
                                    pendingSerial);
             ASSERT(pendingSerial > mCompletedSerial.load());
@@ -220,8 +217,8 @@
 
         TRACE_EVENT_ASYNC_BEGIN0(GetPlatform(), GPUWork, "DeviceMTL::SubmitPendingCommandBuffer",
                                  pendingSerial);
-        [mPendingCommands commit];
-        mPendingCommands = nil;
+        [pendingCommands commit];
+        [pendingCommands release];
     }
 
     MapRequestTracker* Device::GetMapTracker() const {
@@ -242,15 +239,11 @@
                                                uint64_t size) {
         id<MTLBuffer> uploadBuffer = ToBackend(source)->GetBufferHandle();
         id<MTLBuffer> buffer = ToBackend(destination)->GetMTLBuffer();
-        id<MTLCommandBuffer> commandBuffer = GetPendingCommandBuffer();
-        id<MTLBlitCommandEncoder> encoder = [commandBuffer blitCommandEncoder];
-        [encoder copyFromBuffer:uploadBuffer
-                   sourceOffset:sourceOffset
-                       toBuffer:buffer
-              destinationOffset:destinationOffset
-                           size:size];
-        [encoder endEncoding];
-
+        [GetPendingCommandContext()->EnsureBlit() copyFromBuffer:uploadBuffer
+                                                    sourceOffset:sourceOffset
+                                                        toBuffer:buffer
+                                               destinationOffset:destinationOffset
+                                                            size:size];
         return {};
     }
 
@@ -273,8 +266,7 @@
     }
 
     MaybeError Device::WaitForIdleForDestruction() {
-        [mPendingCommands release];
-        mPendingCommands = nil;
+        [mCommandContext.AcquireCommands() release];
 
         // Wait for all commands to be finished so we can free resources
         while (GetCompletedCommandSerial() != mLastSubmittedSerial) {
@@ -285,10 +277,7 @@
     }
 
     void Device::Destroy() {
-        if (mPendingCommands != nil) {
-            [mPendingCommands release];
-            mPendingCommands = nil;
-        }
+        [mCommandContext.AcquireCommands() release];
 
         mMapTracker = nullptr;
         mDynamicUploader = nullptr;
diff --git a/src/dawn_native/metal/QueueMTL.mm b/src/dawn_native/metal/QueueMTL.mm
index dd360e9..7c5967a 100644
--- a/src/dawn_native/metal/QueueMTL.mm
+++ b/src/dawn_native/metal/QueueMTL.mm
@@ -27,11 +27,11 @@
     MaybeError Queue::SubmitImpl(uint32_t commandCount, CommandBufferBase* const* commands) {
         Device* device = ToBackend(GetDevice());
         device->Tick();
-        id<MTLCommandBuffer> commandBuffer = device->GetPendingCommandBuffer();
+        CommandRecordingContext* commandContext = device->GetPendingCommandContext();
 
         TRACE_EVENT_BEGIN0(GetDevice()->GetPlatform(), Recording, "CommandBufferMTL::FillCommands");
         for (uint32_t i = 0; i < commandCount; ++i) {
-            ToBackend(commands[i])->FillCommands(commandBuffer);
+            ToBackend(commands[i])->FillCommands(commandContext);
         }
         TRACE_EVENT_END0(GetDevice()->GetPlatform(), Recording, "CommandBufferMTL::FillCommands");
 
diff --git a/src/tests/white_box/MetalAutoreleasePoolTests.mm b/src/tests/white_box/MetalAutoreleasePoolTests.mm
new file mode 100644
index 0000000..a5c49d6
--- /dev/null
+++ b/src/tests/white_box/MetalAutoreleasePoolTests.mm
@@ -0,0 +1,61 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/DawnTest.h"
+
+#include "dawn_native/metal/DeviceMTL.h"
+
+using namespace dawn_native::metal;
+
+class MetalAutoreleasePoolTests : public DawnTest {
+  private:
+    void TestSetUp() override {
+        DAWN_SKIP_TEST_IF(UsesWire());
+
+        mMtlDevice = reinterpret_cast<Device*>(device.Get());
+    }
+
+  protected:
+    Device* mMtlDevice = nullptr;
+};
+
+// Test that the MTLCommandBuffer owned by the pending command context can
+// outlive an autoreleasepool block.
+TEST_P(MetalAutoreleasePoolTests, CommandBufferOutlivesAutorelease) {
+    @autoreleasepool {
+        // Get the recording context which will allocate a MTLCommandBuffer.
+        // It will get autoreleased at the end of this block.
+        mMtlDevice->GetPendingCommandContext();
+    }
+
+    // Submitting the command buffer should succeed.
+    mMtlDevice->SubmitPendingCommandBuffer();
+}
+
+// Test that the MTLBlitCommandEncoder owned by the pending command context
+// can outlive an autoreleasepool block.
+TEST_P(MetalAutoreleasePoolTests, EncoderOutlivesAutorelease) {
+    @autoreleasepool {
+        // Get the recording context which will allocate a MTLCommandBuffer.
+        // Begin a blit encoder.
+        // Both will get autoreleased at the end of this block.
+        mMtlDevice->GetPendingCommandContext()->EnsureBlit();
+    }
+
+    // Submitting the command buffer should succeed.
+    mMtlDevice->GetPendingCommandContext()->EndBlit();
+    mMtlDevice->SubmitPendingCommandBuffer();
+}
+
+DAWN_INSTANTIATE_TEST(MetalAutoreleasePoolTests, MetalBackend);