d3d11: use Texture::Copy to copy data to the staging texture in CopyTextureToBuffer

Bug: dawn:1768
Change-Id: I7f6488cf54764585eef61766c3d2a829ef186783
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/128580
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Peng Huang <penghuang@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/dawn/native/d3d11/CommandBufferD3D11.cpp b/src/dawn/native/d3d11/CommandBufferD3D11.cpp
index 05cf570..359e46f 100644
--- a/src/dawn/native/d3d11/CommandBufferD3D11.cpp
+++ b/src/dawn/native/d3d11/CommandBufferD3D11.cpp
@@ -367,54 +367,38 @@
                 DAWN_TRY(ToBackend(src.texture)
                              ->EnsureSubresourceContentInitialized(commandContext, subresources));
 
-                // Create a staging texture.
-                // TODO(dawn:1768): use compute shader to copy data from texture to buffer.
-                D3D11_TEXTURE2D_DESC stagingTextureDesc;
-                stagingTextureDesc.Width = copy->copySize.width;
-                stagingTextureDesc.Height = copy->copySize.height;
-                stagingTextureDesc.MipLevels = 1;
-                stagingTextureDesc.ArraySize = copy->copySize.depthOrArrayLayers;
-                stagingTextureDesc.Format = ToBackend(src.texture)->GetD3D11Format();
-                stagingTextureDesc.SampleDesc.Count = 1;
-                stagingTextureDesc.SampleDesc.Quality = 0;
-                stagingTextureDesc.Usage = D3D11_USAGE_STAGING;
-                stagingTextureDesc.BindFlags = 0;
-                stagingTextureDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
-                stagingTextureDesc.MiscFlags = 0;
+                // TODO(dawn:1768): support 3D textures. The staging texture is mapped below
+                // with the loop index z as the D3D11 subresource index, which only addresses
+                // array layers of 1D/2D textures; a 3D texture has one subresource per mip.
+                if (src.texture->GetDimension() == wgpu::TextureDimension::e3D) {
+                    return DAWN_UNIMPLEMENTED_ERROR(
+                        "CopyTextureToBuffer is not implemented for 3D textures");
+                }
+
+                // Copy the source region into a CPU-readable staging texture, then map it below.
+                TextureDescriptor desc = {};
+                desc.label = "CopyTextureToBufferStaging";
+                desc.dimension = src.texture->GetDimension();
+                desc.size.width = copy->copySize.width;
+                desc.size.height = copy->copySize.height;
+                desc.size.depthOrArrayLayers = copy->copySize.depthOrArrayLayers;
+                desc.format = src.texture->GetFormat().format;
+                desc.mipLevelCount = 1;
+                desc.sampleCount = 1;
 
-                ComPtr<ID3D11Texture2D> stagingTexture;
-                DAWN_TRY(CheckHRESULT(commandContext->GetD3D11Device()->CreateTexture2D(
-                                          &stagingTextureDesc, nullptr, &stagingTexture),
-                                      "D3D11 create staging texture"));
+                Ref<Texture> stagingTexture;
+                DAWN_TRY_ASSIGN(stagingTexture,
+                                Texture::CreateStaging(ToBackend(GetDevice()), &desc));
 
-                uint32_t subresource =
-                    src.texture->GetSubresourceIndex(src.mipLevel, src.origin.z, src.aspect);
+                CopyTextureToTextureCmd copyToStagingCmd;
+                copyToStagingCmd.source = src;
+                copyToStagingCmd.destination.texture = stagingTexture.Get();
+                copyToStagingCmd.destination.origin = {0, 0, 0};
+                copyToStagingCmd.destination.mipLevel = 0;
+                copyToStagingCmd.destination.aspect = src.aspect;
+                copyToStagingCmd.copySize = copy->copySize;
 
-                if (src.texture->GetDimension() != wgpu::TextureDimension::e2D) {
-                    return DAWN_UNIMPLEMENTED_ERROR(
-                        "CopyTextureToBuffer is not implemented for non-2D textures");
-                } else {
-                    for (uint32_t z = 0; z < copy->copySize.depthOrArrayLayers; ++z) {
-                        // Copy the texture to the staging texture.
-                        if (src.texture->GetFormat().HasDepthOrStencil()) {
-                            d3d11DeviceContext1->CopySubresourceRegion(
-                                stagingTexture.Get(), z, 0, 0, 0,
-                                ToBackend(src.texture)->GetD3D11Resource(), subresource, nullptr);
-                        } else {
-                            D3D11_BOX srcBox;
-                            srcBox.left = src.origin.x;
-                            srcBox.right = src.origin.x + copy->copySize.width;
-                            srcBox.top = src.origin.y;
-                            srcBox.bottom = src.origin.y + copy->copySize.height;
-                            srcBox.front = 0;
-                            srcBox.back = 1;
-
-                            d3d11DeviceContext1->CopySubresourceRegion(
-                                stagingTexture.Get(), z, 0, 0, 0,
-                                ToBackend(src.texture)->GetD3D11Resource(), subresource, &srcBox);
-                        }
-                    }
-                }
+                DAWN_TRY(Texture::Copy(commandContext, &copyToStagingCmd));
 
                 Buffer* buffer = ToBackend(dst.buffer.Get());
                 Buffer::ScopedMap scopedDstMap;
@@ -428,7 +412,7 @@
                     // TODO(dawn:1705): avoid blocking the CPU.
                     D3D11_MAPPED_SUBRESOURCE mappedResource;
                     DAWN_TRY(
-                        CheckHRESULT(d3d11DeviceContext1->Map(stagingTexture.Get(), z,
+                        CheckHRESULT(d3d11DeviceContext1->Map(stagingTexture->GetD3D11Resource(), z,
                                                               D3D11_MAP_READ, 0, &mappedResource),
                                      "D3D11 map staging texture"));
 
@@ -462,7 +446,7 @@
                             }
                         }
                     }
-                    d3d11DeviceContext1->Unmap(stagingTexture.Get(), z);
+                    d3d11DeviceContext1->Unmap(stagingTexture->GetD3D11Resource(), z);
                 }
 
                 dst.buffer->MarkUsedInPendingCommands();
diff --git a/src/dawn/native/d3d11/TextureD3D11.cpp b/src/dawn/native/d3d11/TextureD3D11.cpp
index ecdba76..eb3fda3 100644
--- a/src/dawn/native/d3d11/TextureD3D11.cpp
+++ b/src/dawn/native/d3d11/TextureD3D11.cpp
@@ -53,7 +53,8 @@
 
 // static
 ResultOrError<Ref<Texture>> Texture::Create(Device* device, const TextureDescriptor* descriptor) {
-    Ref<Texture> texture = AcquireRef(new Texture(device, descriptor, TextureState::OwnedInternal));
+    Ref<Texture> texture = AcquireRef(
+        new Texture(device, descriptor, TextureState::OwnedInternal, /*isStaging=*/false));
     DAWN_TRY(texture->InitializeAsInternalTexture());
     return std::move(texture);
 }
@@ -62,74 +63,86 @@
 ResultOrError<Ref<Texture>> Texture::Create(Device* device,
                                             const TextureDescriptor* descriptor,
                                             ComPtr<ID3D11Resource> d3d11Texture) {
-    Ref<Texture> dawnTexture =
-        AcquireRef(new Texture(device, descriptor, TextureState::OwnedExternal));
+    Ref<Texture> dawnTexture = AcquireRef(
+        new Texture(device, descriptor, TextureState::OwnedExternal, /*isStaging=*/false));
     DAWN_TRY(dawnTexture->InitializeAsSwapChainTexture(std::move(d3d11Texture)));
     return std::move(dawnTexture);
 }
 
+ResultOrError<Ref<Texture>> Texture::CreateStaging(Device* device,
+                                                   const TextureDescriptor* descriptor) {
+    Ref<Texture> texture = AcquireRef(
+        new Texture(device, descriptor, TextureState::OwnedInternal, /*isStaging=*/true));
+    DAWN_TRY(texture->InitializeAsInternalTexture());
+    return std::move(texture);
+}
+
+// Builds the D3D11_TEXTURE{1D,2D,3D}_DESC matching this texture's size, format and usage.
+template <typename T>
+T Texture::GetD3D11TextureDesc() const {
+    T desc = {};
+
+    if constexpr (std::is_same<T, D3D11_TEXTURE1D_DESC>::value) {
+        desc.Width = GetSize().width;
+        desc.ArraySize = GetArrayLayers();
+        desc.MiscFlags = 0;
+    } else if constexpr (std::is_same<T, D3D11_TEXTURE2D_DESC>::value) {
+        desc.Width = GetSize().width;
+        desc.Height = GetSize().height;
+        desc.ArraySize = GetArrayLayers();
+        desc.SampleDesc.Count = GetSampleCount();
+        desc.SampleDesc.Quality = 0;
+        desc.MiscFlags = 0;
+        if (GetArrayLayers() >= 6 && desc.Width == desc.Height) {
+            // The texture has at least 6 square layers, so it can be used as a cube map.
+            desc.MiscFlags |= D3D11_RESOURCE_MISC_TEXTURECUBE;
+        }
+    } else if constexpr (std::is_same<T, D3D11_TEXTURE3D_DESC>::value) {
+        desc.Width = GetSize().width;
+        desc.Height = GetSize().height;
+        desc.Depth = GetSize().depthOrArrayLayers;
+        desc.MiscFlags = 0;
+    }
+
+    desc.MipLevels = static_cast<UINT16>(GetNumMipLevels());
+    desc.Format = d3d::DXGITextureFormat(GetFormat().format);
+    desc.Usage = mIsStaging ? D3D11_USAGE_STAGING : D3D11_USAGE_DEFAULT;
+    // D3D11 does not allow staging resources to be bound to any pipeline stage.
+    desc.BindFlags = mIsStaging ? 0 : D3D11TextureBindFlags(GetUsage(), GetFormat());
+    constexpr UINT kCPUReadWriteFlags = D3D11_CPU_ACCESS_READ | D3D11_CPU_ACCESS_WRITE;
+    desc.CPUAccessFlags = mIsStaging ? kCPUReadWriteFlags : 0;
+
+    return desc;
+}
+
 MaybeError Texture::InitializeAsInternalTexture() {
     Device* device = ToBackend(GetDevice());
 
-    DXGI_FORMAT dxgiFormat = d3d::DXGITextureFormat(GetFormat().format);
-
     switch (GetDimension()) {
         case wgpu::TextureDimension::e1D: {
-            D3D11_TEXTURE1D_DESC textureDescriptor;
-            textureDescriptor.Width = GetSize().width;
-            textureDescriptor.MipLevels = static_cast<UINT16>(GetNumMipLevels());
-            textureDescriptor.ArraySize = 1;
-            textureDescriptor.Format = dxgiFormat;
-            textureDescriptor.Usage = D3D11_USAGE_DEFAULT;
-            textureDescriptor.BindFlags = D3D11TextureBindFlags(GetUsage(), GetFormat());
-            textureDescriptor.CPUAccessFlags = 0;
-            textureDescriptor.MiscFlags = 0;
+            D3D11_TEXTURE1D_DESC desc = GetD3D11TextureDesc<D3D11_TEXTURE1D_DESC>();
             ComPtr<ID3D11Texture1D> d3d11Texture1D;
-            DAWN_TRY(CheckOutOfMemoryHRESULT(device->GetD3D11Device()->CreateTexture1D(
-                                                 &textureDescriptor, nullptr, &d3d11Texture1D),
-                                             "D3D11 create texture1d"));
+            DAWN_TRY(CheckOutOfMemoryHRESULT(
+                device->GetD3D11Device()->CreateTexture1D(&desc, nullptr, &d3d11Texture1D),
+                "D3D11 create texture1d"));
             mD3d11Resource = std::move(d3d11Texture1D);
             break;
         }
         case wgpu::TextureDimension::e2D: {
-            D3D11_TEXTURE2D_DESC textureDescriptor;
-            textureDescriptor.Width = GetSize().width;
-            textureDescriptor.Height = GetSize().height;
-            textureDescriptor.MipLevels = static_cast<UINT16>(GetNumMipLevels());
-            textureDescriptor.ArraySize = GetArrayLayers();
-            textureDescriptor.Format = dxgiFormat;
-            textureDescriptor.SampleDesc.Count = GetSampleCount();
-            textureDescriptor.SampleDesc.Quality = 0;
-            textureDescriptor.Usage = D3D11_USAGE_DEFAULT;
-            textureDescriptor.BindFlags = D3D11TextureBindFlags(GetUsage(), GetFormat());
-            textureDescriptor.CPUAccessFlags = 0;
-            textureDescriptor.MiscFlags = 0;
-            if (GetArrayLayers() >= 6) {
-                // Texture layers are more than 6. It can be used as a cube map.
-                textureDescriptor.MiscFlags |= D3D11_RESOURCE_MISC_TEXTURECUBE;
-            }
+            D3D11_TEXTURE2D_DESC desc = GetD3D11TextureDesc<D3D11_TEXTURE2D_DESC>();
             ComPtr<ID3D11Texture2D> d3d11Texture2D;
-            DAWN_TRY(CheckOutOfMemoryHRESULT(device->GetD3D11Device()->CreateTexture2D(
-                                                 &textureDescriptor, nullptr, &d3d11Texture2D),
-                                             "D3D11 create texture2d"));
+            DAWN_TRY(CheckOutOfMemoryHRESULT(
+                device->GetD3D11Device()->CreateTexture2D(&desc, nullptr, &d3d11Texture2D),
+                "D3D11 create texture2d"));
             mD3d11Resource = std::move(d3d11Texture2D);
             break;
         }
         case wgpu::TextureDimension::e3D: {
-            D3D11_TEXTURE3D_DESC textureDescriptor;
-            textureDescriptor.Width = GetSize().width;
-            textureDescriptor.Height = GetSize().height;
-            textureDescriptor.Depth = GetSize().depthOrArrayLayers;
-            textureDescriptor.MipLevels = static_cast<UINT16>(GetNumMipLevels());
-            textureDescriptor.Format = dxgiFormat;
-            textureDescriptor.Usage = D3D11_USAGE_DEFAULT;
-            textureDescriptor.BindFlags = D3D11TextureBindFlags(GetUsage(), GetFormat());
-            textureDescriptor.CPUAccessFlags = 0;
-            textureDescriptor.MiscFlags = 0;
+            D3D11_TEXTURE3D_DESC desc = GetD3D11TextureDesc<D3D11_TEXTURE3D_DESC>();
             ComPtr<ID3D11Texture3D> d3d11Texture3D;
-            DAWN_TRY(CheckOutOfMemoryHRESULT(device->GetD3D11Device()->CreateTexture3D(
-                                                 &textureDescriptor, nullptr, &d3d11Texture3D),
-                                             "D3D11 create texture3d"));
+            DAWN_TRY(CheckOutOfMemoryHRESULT(
+                device->GetD3D11Device()->CreateTexture3D(&desc, nullptr, &d3d11Texture3D),
+                "D3D11 create texture3d"));
             mD3d11Resource = std::move(d3d11Texture3D);
             break;
         }
@@ -152,8 +165,11 @@
     return {};
 }
 
-Texture::Texture(Device* device, const TextureDescriptor* descriptor, TextureState state)
-    : TextureBase(device, descriptor, state) {}
+Texture::Texture(Device* device,
+                 const TextureDescriptor* descriptor,
+                 TextureState state,
+                 bool isStaging)
+    : TextureBase(device, descriptor, state), mIsStaging(isStaging) {}
 
 Texture::~Texture() = default;
 
@@ -386,10 +402,20 @@
     if (IsCompleteSubresourceCopiedTo(dst.texture.Get(), copy->copySize, dst.mipLevel)) {
         dst.texture->SetIsSubresourceContentInitialized(true, subresources);
     } else {
+        // Partial update subresource of a depth/stencil texture is not allowed.
+        DAWN_ASSERT(!dst.texture->GetFormat().HasDepthOrStencil());
         DAWN_TRY(ToBackend(dst.texture)
                      ->EnsureSubresourceContentInitialized(commandContext, subresources));
     }
 
+    // D3D11 requires a null source box (a whole-subresource copy) for depth/stencil resources.
+    // Check completeness at the copied mip levels rather than against the base texture size.
+    bool isWholeTextureCopy =
+        IsCompleteSubresourceCopiedTo(src.texture.Get(), copy->copySize, src.mipLevel) &&
+        IsCompleteSubresourceCopiedTo(dst.texture.Get(), copy->copySize, dst.mipLevel);
+    // Partial update subresource of a depth/stencil texture is not allowed.
+    DAWN_ASSERT(isWholeTextureCopy || !dst.texture->GetFormat().HasDepthOrStencil());
+
     D3D11_BOX srcBox;
     srcBox.left = src.origin.x;
     srcBox.right = src.origin.x + copy->copySize.width;
@@ -402,7 +428,8 @@
 
     commandContext->GetD3D11DeviceContext1()->CopySubresourceRegion(
         ToBackend(dst.texture)->GetD3D11Resource(), dst.mipLevel, dst.origin.x, dst.origin.y,
-        dst.origin.z, ToBackend(src.texture)->GetD3D11Resource(), subresource, &srcBox);
+        dst.origin.z, ToBackend(src.texture)->GetD3D11Resource(), subresource,
+        isWholeTextureCopy ? nullptr : &srcBox);
 
     return {};
 }
diff --git a/src/dawn/native/d3d11/TextureD3D11.h b/src/dawn/native/d3d11/TextureD3D11.h
index 3eba67c..9190a8c 100644
--- a/src/dawn/native/d3d11/TextureD3D11.h
+++ b/src/dawn/native/d3d11/TextureD3D11.h
@@ -37,6 +37,8 @@
     static ResultOrError<Ref<Texture>> Create(Device* device,
                                               const TextureDescriptor* descriptor,
                                               ComPtr<ID3D11Resource> d3d11Texture);
+    static ResultOrError<Ref<Texture>> CreateStaging(Device* device,
+                                                     const TextureDescriptor* descriptor);
 
     DXGI_FORMAT GetD3D11Format() const;
     ID3D11Resource* GetD3D11Resource() const;
@@ -62,9 +64,14 @@
     static MaybeError Copy(CommandRecordingContext* commandContext, CopyTextureToTextureCmd* copy);
 
   private:
-    Texture(Device* device, const TextureDescriptor* descriptor, TextureState state);
+    Texture(Device* device,
+            const TextureDescriptor* descriptor,
+            TextureState state,
+            bool isStaging);
     ~Texture() override;
-    using TextureBase::TextureBase;
+
+    template <typename T>
+    T GetD3D11TextureDesc() const;  // T is D3D11_TEXTURE1D_DESC, _2D_DESC or _3D_DESC.
 
     MaybeError InitializeAsInternalTexture();
     MaybeError InitializeAsSwapChainTexture(ComPtr<ID3D11Resource> d3d11Texture);
@@ -79,6 +86,7 @@
                      const SubresourceRange& range,
                      TextureBase::ClearValue clearValue);
 
+    const bool mIsStaging = false;  // If true, the resource uses D3D11_USAGE_STAGING.
     ComPtr<ID3D11Resource> mD3d11Resource;
 };