d3d11: Tighten keyed mutex acquire scope

Acquire and release keyed mutex in tight texture synchronization scope
to allow co-operative "concurrent" use with other clients. Keyed mutexes
are acquired in SynchronizeTextureBeforeUse and released on submit.

Bug: dawn:2311
Change-Id: I782c667f7b1d4aec84f1adadb7a97fb990554ae4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/168121
Auto-Submit: Sunny Sachanandani <sunnyps@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Sunny Sachanandani <sunnyps@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/include/dawn/native/D3DBackend.h b/include/dawn/native/D3DBackend.h
index a0ba6cc..bf1979e 100644
--- a/include/dawn/native/D3DBackend.h
+++ b/include/dawn/native/D3DBackend.h
@@ -51,6 +51,9 @@
     ::LUID adapterLUID;
 };
 
+// Chrome uses 0 as acquire key.
+static constexpr UINT64 kDXGIKeyedMutexAcquireKey = 0;
+
 struct DAWN_NATIVE_EXPORT ExternalImageDescriptorDXGISharedHandle : ExternalImageDescriptor {
   public:
     ExternalImageDescriptorDXGISharedHandle();
diff --git a/src/dawn/native/d3d/DeviceD3D.h b/src/dawn/native/d3d/DeviceD3D.h
index 9e0776b..cd1dd7b 100644
--- a/src/dawn/native/d3d/DeviceD3D.h
+++ b/src/dawn/native/d3d/DeviceD3D.h
@@ -40,6 +40,7 @@
 struct ExternalImageDXGIFenceDescriptor;
 class ExternalImageDXGIImpl;
 class Fence;
+class KeyedMutex;
 class PlatformFunctions;
 
 class Device : public DeviceBase {
@@ -64,6 +65,7 @@
     virtual Ref<TextureBase> CreateD3DExternalTexture(
         const UnpackedPtr<TextureDescriptor>& descriptor,
         ComPtr<IUnknown> d3dTexture,
+        ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
         std::vector<FenceAndSignalValue> waitFences,
         bool isSwapChainTexture,
         bool isInitialized) = 0;
diff --git a/src/dawn/native/d3d/ExternalImageDXGIImpl.cpp b/src/dawn/native/d3d/ExternalImageDXGIImpl.cpp
index bad0e62..c485413 100644
--- a/src/dawn/native/d3d/ExternalImageDXGIImpl.cpp
+++ b/src/dawn/native/d3d/ExternalImageDXGIImpl.cpp
@@ -83,15 +83,12 @@
                              textureDescriptor->nextInChain)
                              ->internalUsage;
     }
-
     // If the resource has IDXGIKeyedMutex interface, it will be used for synchronization.
-    // TODO(dawn:1906): remove the mDXGIKeyedMutex when it is not used in chrome.
     mD3DResource.As(&mDXGIKeyedMutex);
 }
 
 ExternalImageDXGIImpl::~ExternalImageDXGIImpl() {
     DAWN_ASSERT(mBackendDevice->IsLockedByCurrentThreadIfNeeded());
-    mDXGIKeyedMutexReleaser.reset();
     DestroyInternal();
 }
 
@@ -169,20 +166,9 @@
 
     Ref<TextureBase> texture =
         ToBackend(mBackendDevice.Get())
-            ->CreateD3DExternalTexture(Unpack(&textureDescriptor), mD3DResource,
+            ->CreateD3DExternalTexture(Unpack(&textureDescriptor), mD3DResource, mDXGIKeyedMutex,
                                        std::move(waitFences), descriptor->isSwapChainTexture,
                                        descriptor->isInitialized);
-
-    if (mDXGIKeyedMutex && mAccessCount == 0) {
-        HRESULT hr = mDXGIKeyedMutex->AcquireSync(kDXGIKeyedMutexAcquireKey, INFINITE);
-        if (FAILED(hr)) {
-            dawn::ErrorLog() << "Failed to acquire keyed mutex for external image";
-            return nullptr;
-        }
-        mDXGIKeyedMutexReleaser.emplace(mDXGIKeyedMutex);
-    }
-    ++mAccessCount;
-
     return ToAPI(ReturnToAPI(std::move(texture)));
 }
 
@@ -216,11 +202,6 @@
 
     signalFence->fenceHandle = sharedFence->GetFenceHandle();
     signalFence->fenceValue = static_cast<uint64_t>(fenceValue);
-
-    --mAccessCount;
-    if (mDXGIKeyedMutexReleaser && mAccessCount == 0) {
-        mDXGIKeyedMutexReleaser.reset();
-    }
 }
 
 }  // namespace dawn::native::d3d
diff --git a/src/dawn/native/d3d/ExternalImageDXGIImpl.h b/src/dawn/native/d3d/ExternalImageDXGIImpl.h
index f8cbc08..6a76eb5 100644
--- a/src/dawn/native/d3d/ExternalImageDXGIImpl.h
+++ b/src/dawn/native/d3d/ExternalImageDXGIImpl.h
@@ -47,6 +47,7 @@
 namespace dawn::native::d3d {
 
 class Device;
+class KeyedMutex;
 struct ExternalImageDXGIBeginAccessDescriptor;
 struct ExternalImageDXGIFenceDescriptor;
 struct ExternalImageDescriptorDXGISharedHandle;
@@ -86,20 +87,6 @@
     uint32_t mMipLevelCount;
     uint32_t mSampleCount;
     std::vector<wgpu::TextureFormat> mViewFormats;
-    uint32_t mAccessCount = 0;
-
-    // Chrome uses 0 as acquire key.
-    static constexpr UINT64 kDXGIKeyedMutexAcquireKey = 0;
-    class KeyedMutexReleaser : public NonCopyable {
-      public:
-        explicit KeyedMutexReleaser(ComPtr<IDXGIKeyedMutex> keyedMutex)
-            : mDXGIKeyedMutex(std::move(keyedMutex)) {}
-        ~KeyedMutexReleaser() { mDXGIKeyedMutex->ReleaseSync(kDXGIKeyedMutexAcquireKey); }
-
-      private:
-        const ComPtr<IDXGIKeyedMutex> mDXGIKeyedMutex;
-    };
-    std::optional<KeyedMutexReleaser> mDXGIKeyedMutexReleaser;
 };
 
 }  // namespace dawn::native::d3d
diff --git a/src/dawn/native/d3d/SharedTextureMemoryD3D.cpp b/src/dawn/native/d3d/SharedTextureMemoryD3D.cpp
index 58e2ec1..d83751f 100644
--- a/src/dawn/native/d3d/SharedTextureMemoryD3D.cpp
+++ b/src/dawn/native/d3d/SharedTextureMemoryD3D.cpp
@@ -40,13 +40,8 @@
 
 SharedTextureMemory::SharedTextureMemory(d3d::Device* device,
                                          const char* label,
-                                         SharedTextureMemoryProperties properties,
-                                         IUnknown* resource)
-    : SharedTextureMemoryBase(device, label, properties) {
-    // If the resource has IDXGIKeyedMutex interface, it will be used for synchronization.
-    // TODO(dawn:1906): remove the mDXGIKeyedMutex when it is not used in chrome.
-    resource->QueryInterface(IID_PPV_ARGS(&mDXGIKeyedMutex));
-}
+                                         SharedTextureMemoryProperties properties)
+    : SharedTextureMemoryBase(device, label, properties) {}
 
 MaybeError SharedTextureMemory::BeginAccessImpl(
     TextureBase* texture,
@@ -68,13 +63,6 @@
                 return DAWN_VALIDATION_ERROR("Unsupported fence type %s.", exportInfo.type);
         }
     }
-
-    // Acquire keyed mutex for the first access.
-    if (mDXGIKeyedMutex &&
-        (HasWriteAccess() || HasExclusiveReadAccess() || GetReadAccessCount() == 1)) {
-        DAWN_TRY(CheckHRESULT(mDXGIKeyedMutex->AcquireSync(kDXGIKeyedMutexAcquireKey, INFINITE),
-                              "Acquire keyed mutex"));
-    }
     return {};
 }
 
@@ -86,12 +74,6 @@
                     "Required feature (%s) is missing.",
                     wgpu::FeatureName::SharedFenceDXGISharedHandle);
 
-    // Release keyed mutex for the last access.
-    if (mDXGIKeyedMutex && !HasWriteAccess() && !HasExclusiveReadAccess() &&
-        GetReadAccessCount() == 0) {
-        mDXGIKeyedMutex->ReleaseSync(kDXGIKeyedMutexAcquireKey);
-    }
-
     Ref<SharedFence> sharedFence;
     DAWN_TRY_ASSIGN(sharedFence, ToBackend(GetDevice()->GetQueue())->GetOrCreateSharedFence());
 
diff --git a/src/dawn/native/d3d/SharedTextureMemoryD3D.h b/src/dawn/native/d3d/SharedTextureMemoryD3D.h
index 3a2c305..cfdf697 100644
--- a/src/dawn/native/d3d/SharedTextureMemoryD3D.h
+++ b/src/dawn/native/d3d/SharedTextureMemoryD3D.h
@@ -39,21 +39,12 @@
   protected:
     SharedTextureMemory(Device* device,
                         const char* label,
-                        SharedTextureMemoryProperties properties,
-                        IUnknown* resource);
+                        SharedTextureMemoryProperties properties);
 
-  protected:
     MaybeError BeginAccessImpl(TextureBase* texture,
                                const UnpackedPtr<BeginAccessDescriptor>& descriptor) override;
     ResultOrError<FenceAndSignalValue> EndAccessImpl(TextureBase* texture,
                                                      UnpackedPtr<EndAccessState>& state) override;
-
-  private:
-    // If the resource has IDXGIKeyedMutex interface, it will be used for synchronization.
-    // TODO(dawn:1906): remove the mDXGIKeyedMutex when it is not used in chrome.
-    ComPtr<IDXGIKeyedMutex> mDXGIKeyedMutex;
-    // Chrome uses 0 as acquire key.
-    static constexpr UINT64 kDXGIKeyedMutexAcquireKey = 0;
 };
 
 }  // namespace dawn::native::d3d
diff --git a/src/dawn/native/d3d/d3d_platform.h b/src/dawn/native/d3d/d3d_platform.h
index e1944ed..06b63ac 100644
--- a/src/dawn/native/d3d/d3d_platform.h
+++ b/src/dawn/native/d3d/d3d_platform.h
@@ -44,6 +44,17 @@
 #include <DXProgrammableCapture.h>  // NOLINT(build/include_order)
 #include <dxgidebug.h>              // NOLINT(build/include_order)
 
+#include <functional>  // NOLINT(build/include_order)
+#include <utility>     // NOLINT(build/include_order)
+
 using Microsoft::WRL::ComPtr;
+template <typename T>
+struct std::hash<ComPtr<T>> {
+    std::size_t operator()(const ComPtr<T>& v) const noexcept { return std::hash<T*>{}(v.Get()); }
+};
+template <typename T, typename H>
+H AbslHashValue(H state, const ComPtr<T>& v) {
+    return H::combine(std::move(state), v.Get());
+}
 
 #endif  // SRC_DAWN_NATIVE_D3D_D3D_PLATFORM_H_
diff --git a/src/dawn/native/d3d11/CommandBufferD3D11.cpp b/src/dawn/native/d3d11/CommandBufferD3D11.cpp
index 5f6e03f..eaa105f 100644
--- a/src/dawn/native/d3d11/CommandBufferD3D11.cpp
+++ b/src/dawn/native/d3d11/CommandBufferD3D11.cpp
@@ -245,6 +245,9 @@
                 }
                 for (const SyncScopeResourceUsage& scope :
                      GetResourceUsages().computePasses[nextComputePassNumber].dispatchUsages) {
+                    for (TextureBase* texture : scope.textures) {
+                        DAWN_TRY(ToBackend(texture)->SynchronizeTextureBeforeUse(commandContext));
+                    }
                     DAWN_TRY(LazyClearSyncScope(scope));
                 }
                 DAWN_TRY(ExecuteComputePass(commandContext));
diff --git a/src/dawn/native/d3d11/CommandRecordingContextD3D11.cpp b/src/dawn/native/d3d11/CommandRecordingContextD3D11.cpp
index 7d79769..d062143 100644
--- a/src/dawn/native/d3d11/CommandRecordingContextD3D11.cpp
+++ b/src/dawn/native/d3d11/CommandRecordingContextD3D11.cpp
@@ -30,6 +30,7 @@
 #include <string>
 #include <utility>
 
+#include "dawn/native/D3DBackend.h"
 #include "dawn/native/d3d/D3DError.h"
 #include "dawn/native/d3d11/BufferD3D11.h"
 #include "dawn/native/d3d11/DeviceD3D11.h"
@@ -129,6 +130,16 @@
     return {};
 }
 
+MaybeError ScopedCommandRecordingContext::AcquireKeyedMutex(
+    ComPtr<IDXGIKeyedMutex> dxgikeyedMutex) const {
+    if (!Get()->mAcquiredKeyedMutexes.contains(dxgikeyedMutex)) {
+        DAWN_TRY(CheckHRESULT(dxgikeyedMutex->AcquireSync(d3d::kDXGIKeyedMutexAcquireKey, INFINITE),
+                              "Failed to acquire keyed mutex for external image"));
+        Get()->mAcquiredKeyedMutexes.emplace(std::move(dxgikeyedMutex));
+    }
+    return {};
+}
+
 ScopedSwapStateCommandRecordingContext::ScopedSwapStateCommandRecordingContext(
     CommandRecordingContextGuard&& guard)
     : ScopedCommandRecordingContext(std::move(guard)),
@@ -201,6 +212,27 @@
     return {};
 }
 
+void CommandRecordingContext::Destroy() {
+    DAWN_ASSERT(mDevice->IsLockedByCurrentThreadIfNeeded());
+    mIsOpen = false;
+    mUniformBuffer = nullptr;
+    mDevice = nullptr;
+
+    if (mD3D11DeviceContext4) {
+        ID3D11Buffer* nullBuffer = nullptr;
+        mD3D11DeviceContext4->VSSetConstantBuffers(PipelineLayout::kReservedConstantBufferSlot, 1,
+                                                   &nullBuffer);
+        mD3D11DeviceContext4->CSSetConstantBuffers(PipelineLayout::kReservedConstantBufferSlot, 1,
+                                                   &nullBuffer);
+    }
+
+    ReleaseKeyedMutexes();
+
+    mD3D11DeviceContextState = nullptr;
+    mD3D11DeviceContext4 = nullptr;
+    mD3D11Device = nullptr;
+}
+
 // static
 ResultOrError<Ref<BufferBase>> CommandRecordingContext::CreateInternalUniformBuffer(
     DeviceBase* device) {
@@ -229,21 +261,11 @@
                                                &bufferPtr);
 }
 
-void CommandRecordingContext::Release() {
-    if (mIsOpen) {
-        DAWN_ASSERT(mDevice->IsLockedByCurrentThreadIfNeeded());
-        mIsOpen = false;
-        mUniformBuffer = nullptr;
-        mDevice = nullptr;
-        ID3D11Buffer* nullBuffer = nullptr;
-        mD3D11DeviceContext4->VSSetConstantBuffers(PipelineLayout::kReservedConstantBufferSlot, 1,
-                                                   &nullBuffer);
-        mD3D11DeviceContext4->CSSetConstantBuffers(PipelineLayout::kReservedConstantBufferSlot, 1,
-                                                   &nullBuffer);
-        mD3D11DeviceContextState = nullptr;
-        mD3D11DeviceContext4 = nullptr;
-        mD3D11Device = nullptr;
+void CommandRecordingContext::ReleaseKeyedMutexes() {
+    for (auto& dxgikeyedMutex : mAcquiredKeyedMutexes) {
+        dxgikeyedMutex->ReleaseSync(d3d::kDXGIKeyedMutexAcquireKey);
     }
+    mAcquiredKeyedMutexes.clear();
 }
 
 }  // namespace dawn::native::d3d11
diff --git a/src/dawn/native/d3d11/CommandRecordingContextD3D11.h b/src/dawn/native/d3d11/CommandRecordingContextD3D11.h
index b56bdca..f0e1a05 100644
--- a/src/dawn/native/d3d11/CommandRecordingContextD3D11.h
+++ b/src/dawn/native/d3d11/CommandRecordingContextD3D11.h
@@ -28,6 +28,7 @@
 #ifndef SRC_DAWN_NATIVE_D3D11_COMMANDRECORDINGCONTEXT_D3D11_H_
 #define SRC_DAWN_NATIVE_D3D11_COMMANDRECORDINGCONTEXT_D3D11_H_
 
+#include "absl/container/flat_hash_set.h"
 #include "dawn/common/MutexProtected.h"
 #include "dawn/common/NonCopyable.h"
 #include "dawn/common/Ref.h"
@@ -77,11 +78,12 @@
                                      ::dawn::detail::MutexProtectedTraits<CommandRecordingContext>>;
 
     MaybeError Initialize(Device* device);
+    void Destroy();
 
     static ResultOrError<Ref<BufferBase>> CreateInternalUniformBuffer(DeviceBase* device);
     void SetInternalUniformBuffer(Ref<BufferBase> uniformBuffer);
 
-    void Release();
+    void ReleaseKeyedMutexes();
 
   private:
     template <typename Ctx, typename Traits>
@@ -104,6 +106,8 @@
     std::array<uint32_t, kMaxNumBuiltinElements> mUniformBufferData;
     bool mUniformBufferDirty = true;
 
+    absl::flat_hash_set<ComPtr<IDXGIKeyedMutex>> mAcquiredKeyedMutexes;
+
     Ref<Device> mDevice;
 };
 
@@ -149,6 +153,8 @@
     // Write the built-in variable value to the uniform buffer.
     void WriteUniformBuffer(uint32_t offset, uint32_t element) const;
     MaybeError FlushUniformBuffer() const;
+
+    MaybeError AcquireKeyedMutex(ComPtr<IDXGIKeyedMutex> dxgikeyedMutex) const;
 };
 
 // For using ID3D11DeviceContext directly. It swaps and resets ID3DDeviceContextState of
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index be3bc98..7c8de7b 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -481,14 +481,15 @@
 
 Ref<TextureBase> Device::CreateD3DExternalTexture(const UnpackedPtr<TextureDescriptor>& descriptor,
                                                   ComPtr<IUnknown> d3dTexture,
+                                                  ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
                                                   std::vector<FenceAndSignalValue> waitFences,
                                                   bool isSwapChainTexture,
                                                   bool isInitialized) {
     Ref<Texture> dawnTexture;
-    if (ConsumedError(
-            Texture::CreateExternalImage(this, descriptor, std::move(d3dTexture),
-                                         std::move(waitFences), isSwapChainTexture, isInitialized),
-            &dawnTexture)) {
+    if (ConsumedError(Texture::CreateExternalImage(this, descriptor, std::move(d3dTexture),
+                                                   std::move(dxgiKeyedMutex), std::move(waitFences),
+                                                   isSwapChainTexture, isInitialized),
+                      &dawnTexture)) {
         return nullptr;
     }
     return {dawnTexture};
diff --git a/src/dawn/native/d3d11/DeviceD3D11.h b/src/dawn/native/d3d11/DeviceD3D11.h
index 67795f8..7ce600f 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.h
+++ b/src/dawn/native/d3d11/DeviceD3D11.h
@@ -57,6 +57,7 @@
     void ReferenceUntilUnused(ComPtr<IUnknown> object);
     Ref<TextureBase> CreateD3DExternalTexture(const UnpackedPtr<TextureDescriptor>& descriptor,
                                               ComPtr<IUnknown> d3dTexture,
+                                              ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
                                               std::vector<FenceAndSignalValue> waitFences,
                                               bool isSwapChainTexture,
                                               bool isInitialized) override;
diff --git a/src/dawn/native/d3d11/QueueD3D11.cpp b/src/dawn/native/d3d11/QueueD3D11.cpp
index fff390d..14dec3e 100644
--- a/src/dawn/native/d3d11/QueueD3D11.cpp
+++ b/src/dawn/native/d3d11/QueueD3D11.cpp
@@ -91,7 +91,7 @@
     mSharedFence = nullptr;
 
     mPendingCommands.Use([&](auto pendingCommands) {
-        pendingCommands->Release();
+        pendingCommands->Destroy();
         mPendingCommandsNeedSubmit.store(false, std::memory_order_release);
     });
 }
@@ -124,10 +124,14 @@
 }
 
 MaybeError Queue::SubmitPendingCommands() {
-    if (!mPendingCommandsNeedSubmit.exchange(false, std::memory_order_acq_rel)) {
-        return {};
+    bool needsSubmit = mPendingCommands.Use([&](auto pendingCommands) {
+        pendingCommands->ReleaseKeyedMutexes();
+        return mPendingCommandsNeedSubmit.exchange(false, std::memory_order_acq_rel);
+    });
+    if (needsSubmit) {
+        return NextSerial();
     }
-    return NextSerial();
+    return {};
 }
 
 MaybeError Queue::SubmitImpl(uint32_t commandCount, CommandBufferBase* const* commands) {
diff --git a/src/dawn/native/d3d11/SharedTextureMemoryD3D11.cpp b/src/dawn/native/d3d11/SharedTextureMemoryD3D11.cpp
index de12ed9..0dfffb6 100644
--- a/src/dawn/native/d3d11/SharedTextureMemoryD3D11.cpp
+++ b/src/dawn/native/d3d11/SharedTextureMemoryD3D11.cpp
@@ -145,10 +145,12 @@
                                          const char* label,
                                          SharedTextureMemoryProperties properties,
                                          ComPtr<ID3D11Resource> resource)
-    : d3d::SharedTextureMemory(device, label, properties, resource.Get()),
-      mResource(std::move(resource)) {}
+    : d3d::SharedTextureMemory(device, label, properties), mResource(std::move(resource)) {
+    mResource.As(&mKeyedMutex);
+}
 
 void SharedTextureMemory::DestroyImpl() {
+    mKeyedMutex = nullptr;
     mResource = nullptr;
 }
 
@@ -156,6 +158,10 @@
     return mResource.Get();
 }
 
+IDXGIKeyedMutex* SharedTextureMemory::GetKeyedMutex() const {
+    return mKeyedMutex.Get();
+}
+
 ResultOrError<Ref<TextureBase>> SharedTextureMemory::CreateTextureImpl(
     const UnpackedPtr<TextureDescriptor>& descriptor) {
     return Texture::CreateFromSharedTextureMemory(this, descriptor);
diff --git a/src/dawn/native/d3d11/SharedTextureMemoryD3D11.h b/src/dawn/native/d3d11/SharedTextureMemoryD3D11.h
index 7b310f0..576b248 100644
--- a/src/dawn/native/d3d11/SharedTextureMemoryD3D11.h
+++ b/src/dawn/native/d3d11/SharedTextureMemoryD3D11.h
@@ -51,6 +51,8 @@
 
     ID3D11Resource* GetD3DResource() const;
 
+    IDXGIKeyedMutex* GetKeyedMutex() const;
+
   private:
     SharedTextureMemory(Device* device,
                         const char* label,
@@ -63,6 +65,7 @@
         const UnpackedPtr<TextureDescriptor>& descriptor) override;
 
     ComPtr<ID3D11Resource> mResource;
+    ComPtr<IDXGIKeyedMutex> mKeyedMutex;
 };
 
 }  // namespace dawn::native::d3d11
diff --git a/src/dawn/native/d3d11/TextureD3D11.cpp b/src/dawn/native/d3d11/TextureD3D11.cpp
index ba3dbb7..d4ea2aa 100644
--- a/src/dawn/native/d3d11/TextureD3D11.cpp
+++ b/src/dawn/native/d3d11/TextureD3D11.cpp
@@ -43,6 +43,7 @@
 #include "dawn/native/ToBackend.h"
 #include "dawn/native/d3d/D3DError.h"
 #include "dawn/native/d3d/UtilsD3D.h"
+#include "dawn/native/d3d11/CommandRecordingContextD3D11.h"
 #include "dawn/native/d3d11/DeviceD3D11.h"
 #include "dawn/native/d3d11/Forward.h"
 #include "dawn/native/d3d11/QueueD3D11.h"
@@ -229,12 +230,22 @@
     Device* device,
     const UnpackedPtr<TextureDescriptor>& descriptor,
     ComPtr<IUnknown> d3dTexture,
+    ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
     std::vector<FenceAndSignalValue> waitFences,
     bool isSwapChainTexture,
     bool isInitialized) {
     Ref<Texture> dawnTexture = AcquireRef(new Texture(device, descriptor, Kind::Normal));
-    DAWN_TRY(dawnTexture->InitializeAsExternalTexture(std::move(d3dTexture), std::move(waitFences),
-                                                      isSwapChainTexture));
+    DAWN_TRY(
+        dawnTexture->InitializeAsExternalTexture(std::move(d3dTexture), std::move(dxgiKeyedMutex)));
+
+    auto commandContext =
+        ToBackend(device->GetQueue())
+            ->GetScopedPendingCommandContext(ExecutionQueueBase::SubmitMode::Normal);
+    for (const auto& fence : waitFences) {
+        DAWN_TRY(CheckHRESULT(
+            commandContext.Wait(ToBackend(fence.object)->GetD3DFence(), fence.signaledValue),
+            "ID3D11DeviceContext4::Wait"));
+    }
 
     // Importing a multi-planar format must be initialized. This is required because
     // a shared multi-planar format cannot be initialized by Dawn.
@@ -254,7 +265,8 @@
     const UnpackedPtr<TextureDescriptor>& descriptor) {
     Device* device = ToBackend(memory->GetDevice());
     Ref<Texture> texture = AcquireRef(new Texture(device, descriptor, Kind::Normal));
-    DAWN_TRY(texture->InitializeAsExternalTexture(memory->GetD3DResource(), {}, false));
+    DAWN_TRY(
+        texture->InitializeAsExternalTexture(memory->GetD3DResource(), memory->GetKeyedMutex()));
     texture->mSharedTextureMemoryContents = memory->GetContents();
     return texture;
 }
@@ -367,19 +379,11 @@
 }
 
 MaybeError Texture::InitializeAsExternalTexture(ComPtr<IUnknown> d3dTexture,
-                                                std::vector<FenceAndSignalValue> waitFences,
-                                                bool isSwapChainTexture) {
+                                                ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex) {
     ComPtr<ID3D11Resource> d3d11Texture;
     DAWN_TRY(CheckHRESULT(d3dTexture.As(&d3d11Texture), "Query ID3D11Resource from IUnknown"));
-
-    auto commandContext = ToBackend(GetDevice()->GetQueue())
-                              ->GetScopedPendingCommandContext(QueueBase::SubmitMode::Normal);
-    for (const auto& fence : waitFences) {
-        DAWN_TRY(CheckHRESULT(
-            commandContext.Wait(ToBackend(fence.object)->GetD3DFence(), fence.signaledValue),
-            "ID3D11DeviceContext4::Wait"));
-    }
     mD3d11Resource = std::move(d3d11Texture);
+    mDxgiKeyedMutex = std::move(dxgiKeyedMutex);
     SetLabelHelper("Dawn_ExternalTexture");
     return {};
 }
@@ -498,17 +502,19 @@
 
 MaybeError Texture::SynchronizeTextureBeforeUse(
     const ScopedCommandRecordingContext* commandContext) {
-    if (SharedTextureMemoryContents* contents = GetSharedTextureMemoryContents()) {
+    if (auto* contents = GetSharedTextureMemoryContents()) {
         SharedTextureMemoryBase::PendingFenceList fences;
         contents->AcquirePendingFences(&fences);
         contents->SetLastUsageSerial(GetDevice()->GetQueue()->GetPendingCommandSerial());
-
-        for (auto& fence : fences) {
+        for (const auto& fence : fences) {
             DAWN_TRY(CheckHRESULT(
                 commandContext->Wait(ToBackend(fence.object)->GetD3DFence(), fence.signaledValue),
                 "ID3D11DeviceContext4::Wait"));
         }
     }
+    if (mDxgiKeyedMutex) {
+        DAWN_TRY(commandContext->AcquireKeyedMutex(mDxgiKeyedMutex));
+    }
     mLastUsageSerial = GetDevice()->GetQueue()->GetPendingCommandSerial();
     return {};
 }
diff --git a/src/dawn/native/d3d11/TextureD3D11.h b/src/dawn/native/d3d11/TextureD3D11.h
index 99854fc..69d5ad8 100644
--- a/src/dawn/native/d3d11/TextureD3D11.h
+++ b/src/dawn/native/d3d11/TextureD3D11.h
@@ -67,6 +67,7 @@
         Device* device,
         const UnpackedPtr<TextureDescriptor>& descriptor,
         ComPtr<IUnknown> d3dTexture,
+        ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
         std::vector<FenceAndSignalValue> waitFences,
         bool isSwapChainTexture,
         bool isInitialized);
@@ -141,8 +142,7 @@
     MaybeError InitializeAsInternalTexture();
     MaybeError InitializeAsSwapChainTexture(ComPtr<ID3D11Resource> d3d11Texture);
     MaybeError InitializeAsExternalTexture(ComPtr<IUnknown> d3dTexture,
-                                           std::vector<FenceAndSignalValue> waitFences,
-                                           bool isSwapChainTexture);
+                                           ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex);
     void SetLabelHelper(const char* prefix);
 
     // Dawn API
@@ -195,6 +195,7 @@
 
     const Kind mKind = Kind::Normal;
     ComPtr<ID3D11Resource> mD3d11Resource;
+    ComPtr<IDXGIKeyedMutex> mDxgiKeyedMutex;
 
     // TODO(crbug.com/1515640): Remove this once Chromium has migrated to SharedTextureMemory.
     std::optional<ExecutionSerial> mLastUsageSerial;
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 014f1d9..746e308 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -544,9 +544,12 @@
 
 Ref<TextureBase> Device::CreateD3DExternalTexture(const UnpackedPtr<TextureDescriptor>& descriptor,
                                                   ComPtr<IUnknown> d3dTexture,
+                                                  ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
                                                   std::vector<FenceAndSignalValue> waitFences,
                                                   bool isSwapChainTexture,
                                                   bool isInitialized) {
+    // TODO(sunnyps): Reintroduce keyed mutex support.
+    DAWN_ASSERT(dxgiKeyedMutex == nullptr);
     Ref<Texture> dawnTexture;
     if (ConsumedError(
             Texture::CreateExternalImage(this, descriptor, std::move(d3dTexture),
diff --git a/src/dawn/native/d3d12/DeviceD3D12.h b/src/dawn/native/d3d12/DeviceD3D12.h
index 12f37fa..b0ff361 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.h
+++ b/src/dawn/native/d3d12/DeviceD3D12.h
@@ -146,6 +146,7 @@
 
     Ref<TextureBase> CreateD3DExternalTexture(const UnpackedPtr<TextureDescriptor>& descriptor,
                                               ComPtr<IUnknown> d3dTexture,
+                                              ComPtr<IDXGIKeyedMutex> dxgiKeyedMutex,
                                               std::vector<FenceAndSignalValue> waitFences,
                                               bool isSwapChainTexture,
                                               bool isInitialized) override;
diff --git a/src/dawn/native/d3d12/SharedTextureMemoryD3D12.cpp b/src/dawn/native/d3d12/SharedTextureMemoryD3D12.cpp
index b767915..6c6e035 100644
--- a/src/dawn/native/d3d12/SharedTextureMemoryD3D12.cpp
+++ b/src/dawn/native/d3d12/SharedTextureMemoryD3D12.cpp
@@ -101,8 +101,7 @@
                                          const char* label,
                                          SharedTextureMemoryProperties properties,
                                          ComPtr<ID3D12Resource> resource)
-    : d3d::SharedTextureMemory(device, label, properties, resource.Get()),
-      mResource(std::move(resource)) {}
+    : d3d::SharedTextureMemory(device, label, properties), mResource(std::move(resource)) {}
 
 void SharedTextureMemory::DestroyImpl() {
     ToBackend(GetDevice())->ReferenceUntilUnused(std::move(mResource));