Refactor the emulation of store and MSAA resolve

This patch refactors the implementation of the emulation on "store and
MSAA resolve" store operation by moving all the function calls related
to this toggle into one recursive function EncoderRenderPass(). This
refactoring will also make it easier to implement more workarounds on
Metal render pass.

BUG=dawn:56

Change-Id: Ifc737407001e55863835ab994b735e088beda8c6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/7220
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/metal/CommandBufferMTL.h b/src/dawn_native/metal/CommandBufferMTL.h
index 89420b8..4b08a9f 100644
--- a/src/dawn_native/metal/CommandBufferMTL.h
+++ b/src/dawn_native/metal/CommandBufferMTL.h
@@ -28,6 +28,7 @@
 namespace dawn_native { namespace metal {
 
     class Device;
+    struct GlobalEncoders;
 
     class CommandBuffer : public CommandBufferBase {
       public:
@@ -38,7 +39,16 @@
 
       private:
         void EncodeComputePass(id<MTLCommandBuffer> commandBuffer);
-        void EncodeRenderPass(id<MTLCommandBuffer> commandBuffer, BeginRenderPassCmd* renderPass);
+        void EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
+                              MTLRenderPassDescriptor* mtlRenderPass,
+                              GlobalEncoders* globalEncoders,
+                              uint32_t width,
+                              uint32_t height);
+
+        void EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+                                      MTLRenderPassDescriptor* mtlRenderPass,
+                                      uint32_t width,
+                                      uint32_t height);
 
         CommandIterator mCommands;
     };
diff --git a/src/dawn_native/metal/CommandBufferMTL.mm b/src/dawn_native/metal/CommandBufferMTL.mm
index b979963..505cd85 100644
--- a/src/dawn_native/metal/CommandBufferMTL.mm
+++ b/src/dawn_native/metal/CommandBufferMTL.mm
@@ -27,29 +27,27 @@
 
 namespace dawn_native { namespace metal {
 
+    struct GlobalEncoders {
+        id<MTLBlitCommandEncoder> blit = nil;
+
+        void Finish() {
+            if (blit != nil) {
+                [blit endEncoding];
+                blit = nil;  // This will be autoreleased.
+            }
+        }
+
+        void EnsureBlit(id<MTLCommandBuffer> commandBuffer) {
+            if (blit == nil) {
+                blit = [commandBuffer blitCommandEncoder];
+            }
+        }
+    };
+
     namespace {
 
-        struct GlobalEncoders {
-            id<MTLBlitCommandEncoder> blit = nil;
-
-            void Finish() {
-                if (blit != nil) {
-                    [blit endEncoding];
-                    blit = nil;  // This will be autoreleased.
-                }
-            }
-
-            void EnsureBlit(id<MTLCommandBuffer> commandBuffer) {
-                if (blit == nil) {
-                    blit = [commandBuffer blitCommandEncoder];
-                }
-            }
-        };
-
         // Creates an autoreleased MTLRenderPassDescriptor matching desc
-        MTLRenderPassDescriptor* CreateMTLRenderPassDescriptor(
-            BeginRenderPassCmd* renderPass,
-            bool shouldEmulateStoreAndMSAAResolve) {
+        MTLRenderPassDescriptor* CreateMTLRenderPassDescriptor(BeginRenderPassCmd* renderPass) {
             MTLRenderPassDescriptor* descriptor = [MTLRenderPassDescriptor renderPassDescriptor];
 
             for (uint32_t i : IterateBitSet(renderPass->colorAttachmentsSet)) {
@@ -70,8 +68,7 @@
                 descriptor.colorAttachments[i].slice = attachmentInfo.view->GetBaseArrayLayer();
 
                 if (attachmentInfo.storeOp == dawn::StoreOp::Store) {
-                    if (attachmentInfo.resolveTarget.Get() != nullptr &&
-                        !shouldEmulateStoreAndMSAAResolve) {
+                    if (attachmentInfo.resolveTarget.Get() != nullptr) {
                         descriptor.colorAttachments[i].resolveTexture =
                             ToBackend(attachmentInfo.resolveTarget->GetTexture())->GetMTLTexture();
                         descriptor.colorAttachments[i].resolveLevel =
@@ -122,41 +119,32 @@
             return descriptor;
         }
 
-        // Do MSAA resolve in another render pass.
-        void ResolveInAnotherRenderPass(id<MTLCommandBuffer> commandBuffer,
-                                        BeginRenderPassCmd* renderPass) {
-            ASSERT(renderPass->sampleCount > 1);
-            MTLRenderPassDescriptor* renderPassForResolve = nil;
-            for (uint32_t i : IterateBitSet(renderPass->colorAttachmentsSet)) {
-                auto& attachmentInfo = renderPass->colorAttachments[i];
-                if (attachmentInfo.resolveTarget.Get() == nil ||
-                    attachmentInfo.storeOp != dawn::StoreOp::Store) {
+        void ResolveInAnotherRenderPass(
+            id<MTLCommandBuffer> commandBuffer,
+            const MTLRenderPassDescriptor* mtlRenderPass,
+            const std::array<id<MTLTexture>, kMaxColorAttachments>& resolveTextures) {
+            MTLRenderPassDescriptor* mtlRenderPassForResolve =
+                [MTLRenderPassDescriptor renderPassDescriptor];
+            for (uint32_t i = 0; i < kMaxColorAttachments; ++i) {
+                if (resolveTextures[i] == nil) {
                     continue;
                 }
 
-                if (renderPassForResolve == nil) {
-                    renderPassForResolve = [MTLRenderPassDescriptor renderPassDescriptor];
-                }
-                renderPassForResolve.colorAttachments[i].texture =
-                    ToBackend(attachmentInfo.view->GetTexture())->GetMTLTexture();
-                renderPassForResolve.colorAttachments[i].level = 0;
-                renderPassForResolve.colorAttachments[i].slice = 0;
-
-                renderPassForResolve.colorAttachments[i].storeAction =
+                mtlRenderPassForResolve.colorAttachments[i].texture =
+                    mtlRenderPass.colorAttachments[i].texture;
+                mtlRenderPassForResolve.colorAttachments[i].loadAction = MTLLoadActionLoad;
+                mtlRenderPassForResolve.colorAttachments[i].storeAction =
                     MTLStoreActionMultisampleResolve;
-                renderPassForResolve.colorAttachments[i].resolveTexture =
-                    ToBackend(attachmentInfo.resolveTarget->GetTexture())->GetMTLTexture();
-                renderPassForResolve.colorAttachments[i].resolveLevel =
-                    attachmentInfo.resolveTarget->GetBaseMipLevel();
-                renderPassForResolve.colorAttachments[i].resolveSlice =
-                    attachmentInfo.resolveTarget->GetBaseArrayLayer();
+                mtlRenderPassForResolve.colorAttachments[i].resolveTexture = resolveTextures[i];
+                mtlRenderPassForResolve.colorAttachments[i].resolveLevel =
+                    mtlRenderPass.colorAttachments[i].resolveLevel;
+                mtlRenderPassForResolve.colorAttachments[i].resolveSlice =
+                    mtlRenderPass.colorAttachments[i].resolveSlice;
             }
 
-            if (renderPassForResolve != nil) {
-                id<MTLRenderCommandEncoder> encoder =
-                    [commandBuffer renderCommandEncoderWithDescriptor:renderPassForResolve];
-                [encoder endEncoding];
-            }
+            id<MTLRenderCommandEncoder> encoder =
+                [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPassForResolve];
+            [encoder endEncoding];
         }
 
         // Handles a call to SetBindGroup, directing the commands to the correct encoder.
@@ -286,7 +274,8 @@
                 case Command::BeginRenderPass: {
                     BeginRenderPassCmd* cmd = mCommands.NextCommand<BeginRenderPassCmd>();
                     encoders.Finish();
-                    EncodeRenderPass(commandBuffer, cmd);
+                    MTLRenderPassDescriptor* descriptor = CreateMTLRenderPassDescriptor(cmd);
+                    EncodeRenderPass(commandBuffer, descriptor, &encoders, cmd->width, cmd->height);
                 } break;
 
                 case Command::CopyBufferToBuffer: {
@@ -630,7 +619,46 @@
     }
 
     void CommandBuffer::EncodeRenderPass(id<MTLCommandBuffer> commandBuffer,
-                                         BeginRenderPassCmd* renderPassCmd) {
+                                         MTLRenderPassDescriptor* mtlRenderPass,
+                                         GlobalEncoders* globalEncoders,
+                                         uint32_t width,
+                                         uint32_t height) {
+        ASSERT(mtlRenderPass && globalEncoders);
+
+        Device* device = ToBackend(GetDevice());
+
+        // Handle Store + MSAA resolve workaround (Toggle EmulateStoreAndMSAAResolve).
+        if (device->IsToggleEnabled(Toggle::EmulateStoreAndMSAAResolve)) {
+            bool hasStoreAndMSAAResolve = false;
+
+            // Remove any store + MSAA resolve and remember them.
+            std::array<id<MTLTexture>, kMaxColorAttachments> resolveTextures = {};
+            for (uint32_t i = 0; i < kMaxColorAttachments; ++i) {
+                if (mtlRenderPass.colorAttachments[i].storeAction ==
+                    MTLStoreActionStoreAndMultisampleResolve) {
+                    hasStoreAndMSAAResolve = true;
+                    resolveTextures[i] = mtlRenderPass.colorAttachments[i].resolveTexture;
+
+                    mtlRenderPass.colorAttachments[i].storeAction = MTLStoreActionStore;
+                    mtlRenderPass.colorAttachments[i].resolveTexture = nil;
+                }
+            }
+
+            // If we found a store + MSAA resolve we need to resolve in a different render pass.
+            if (hasStoreAndMSAAResolve) {
+                EncodeRenderPass(commandBuffer, mtlRenderPass, globalEncoders, width, height);
+                ResolveInAnotherRenderPass(commandBuffer, mtlRenderPass, resolveTextures);
+                return;
+            }
+        }
+
+        EncodeRenderPassInternal(commandBuffer, mtlRenderPass, width, height);
+    }
+
+    void CommandBuffer::EncodeRenderPassInternal(id<MTLCommandBuffer> commandBuffer,
+                                                 MTLRenderPassDescriptor* mtlRenderPass,
+                                                 uint32_t width,
+                                                 uint32_t height) {
         RenderPipeline* lastPipeline = nullptr;
         id<MTLBuffer> indexBuffer = nil;
         uint32_t indexBufferBaseOffset = 0;
@@ -638,13 +666,9 @@
         std::array<uint32_t, kMaxPushConstants> vertexPushConstants;
         std::array<uint32_t, kMaxPushConstants> fragmentPushConstants;
 
-        bool shouldEmulateStoreAndMSAAResolve =
-            GetDevice()->IsToggleEnabled(Toggle::EmulateStoreAndMSAAResolve);
         // This will be autoreleased
-        id<MTLRenderCommandEncoder> encoder = [commandBuffer
-            renderCommandEncoderWithDescriptor:CreateMTLRenderPassDescriptor(
-                                                   renderPassCmd,
-                                                   shouldEmulateStoreAndMSAAResolve)];
+        id<MTLRenderCommandEncoder> encoder =
+            [commandBuffer renderCommandEncoderWithDescriptor:mtlRenderPass];
 
         // Set default values for push constants
         vertexPushConstants.fill(0);
@@ -663,9 +687,6 @@
                 case Command::EndRenderPass: {
                     mCommands.NextCommand<EndRenderPassCmd>();
                     [encoder endEncoding];
-                    if (renderPassCmd->sampleCount > 1 && shouldEmulateStoreAndMSAAResolve) {
-                        ResolveInAnotherRenderPass(commandBuffer, renderPassCmd);
-                    }
                     return;
                 } break;
 
@@ -770,12 +791,12 @@
                     rect.height = cmd->height;
 
                     // The scissor rect x + width must be <= render pass width
-                    if ((rect.x + rect.width) > renderPassCmd->width) {
-                        rect.width = renderPassCmd->width - rect.x;
+                    if ((rect.x + rect.width) > width) {
+                        rect.width = width - rect.x;
                     }
                     // The scissor rect y + height must be <= render pass height
-                    if ((rect.y + rect.height > renderPassCmd->height)) {
-                        rect.height = renderPassCmd->height - rect.y;
+                    if ((rect.y + rect.height > height)) {
+                        rect.height = height - rect.y;
                     }
 
                     [encoder setScissorRect:rect];