d3d: Change D3D references to PhysicalDevice to WeakRef instead of Ref

BackendD3D holds strong references to PhysicalDeviceD3D11/12, which
are created when enumerating adapters. Once adapters are enumerated,
these are kept alive indefinitely. Upon creation, PhysicalDeviceD3D
has a reference to the D3D11/12Device, which means that the
corresponding physical adapter is kept powered on indefinitely.

This CL changes the strong Ref to a WeakRef and relies on the caller
to hold a strong reference to the PhysicalDevice to keep it alive.
PhysicalDevices that no caller references are destroyed, which
releases their D3D11/12 devices.

Bug: 342299153
Change-Id: I4ff6979abb175f9b737fb3ede4b26334d858d6a4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/189581
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Patrick To <patrto@microsoft.com>
diff --git a/src/dawn/native/PhysicalDevice.h b/src/dawn/native/PhysicalDevice.h
index dfc5bd4..301b7d3 100644
--- a/src/dawn/native/PhysicalDevice.h
+++ b/src/dawn/native/PhysicalDevice.h
@@ -36,6 +36,7 @@
 #include "dawn/common/GPUInfo.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/RefCounted.h"
+#include "dawn/common/WeakRefSupport.h"
 #include "dawn/common/ityp_span.h"
 #include "dawn/native/Device.h"
 #include "dawn/native/Error.h"
@@ -67,7 +68,7 @@
     std::string errorMessage;
 };
 
-class PhysicalDeviceBase : public RefCounted {
+class PhysicalDeviceBase : public RefCounted, public WeakRefSupport<PhysicalDeviceBase> {
   public:
     explicit PhysicalDeviceBase(wgpu::BackendType backend);
     ~PhysicalDeviceBase() override;
diff --git a/src/dawn/native/d3d/BackendD3D.cpp b/src/dawn/native/d3d/BackendD3D.cpp
index 478d60f..29291d5 100644
--- a/src/dawn/native/d3d/BackendD3D.cpp
+++ b/src/dawn/native/d3d/BackendD3D.cpp
@@ -255,19 +255,17 @@
 }
 
 ResultOrError<Ref<PhysicalDeviceBase>> Backend::GetOrCreatePhysicalDeviceFromLUID(LUID luid) {
-    auto it = mPhysicalDevices.find(luid);
-    if (it != mPhysicalDevices.end()) {
-        // If we've already discovered this physical device, return it.
-        return it->second;
+    Ref<PhysicalDeviceBase> physicalDevice = FindPhysicalDevice(luid);
+    if (physicalDevice == nullptr) {
+        ComPtr<IDXGIAdapter1> dxgiAdapter = nullptr;
+        DAWN_TRY(CheckHRESULT(GetFactory()->EnumAdapterByLuid(luid, IID_PPV_ARGS(&dxgiAdapter)),
+                              "EnumAdapterByLuid"));
+
+        DAWN_TRY_ASSIGN(physicalDevice, CreatePhysicalDeviceFromIDXGIAdapter(dxgiAdapter));
+        // The LUID may already exist in the map if the previous WeakRef has been invalidated. In
+        // that case, `insert_or_assign` will replace the stale WeakRef with the new one.
+        mPhysicalDevices.insert_or_assign(luid, GetWeakRef(physicalDevice));
     }
-
-    ComPtr<IDXGIAdapter1> dxgiAdapter = nullptr;
-    DAWN_TRY(CheckHRESULT(GetFactory()->EnumAdapterByLuid(luid, IID_PPV_ARGS(&dxgiAdapter)),
-                          "EnumAdapterByLuid"));
-
-    Ref<PhysicalDeviceBase> physicalDevice;
-    DAWN_TRY_ASSIGN(physicalDevice, CreatePhysicalDeviceFromIDXGIAdapter(dxgiAdapter));
-    mPhysicalDevices.emplace(luid, physicalDevice);
     return physicalDevice;
 }
 
@@ -276,18 +274,26 @@
     DXGI_ADAPTER_DESC desc;
     DAWN_TRY(CheckHRESULT(dxgiAdapter->GetDesc(&desc), "IDXGIAdapter::GetDesc"));
 
-    auto it = mPhysicalDevices.find(desc.AdapterLuid);
-    if (it != mPhysicalDevices.end()) {
-        // If we've already discovered this physical device, return it.
-        return it->second;
+    Ref<PhysicalDeviceBase> physicalDevice = FindPhysicalDevice(desc.AdapterLuid);
+    if (physicalDevice == nullptr) {
+        DAWN_TRY_ASSIGN(physicalDevice, CreatePhysicalDeviceFromIDXGIAdapter(dxgiAdapter));
+        // The LUID may already exist in the map if the previous WeakRef has been invalidated. In
+        // that case, `insert_or_assign` will replace the stale WeakRef with the new one.
+        mPhysicalDevices.insert_or_assign(desc.AdapterLuid, GetWeakRef(physicalDevice));
     }
-
-    Ref<PhysicalDeviceBase> physicalDevice;
-    DAWN_TRY_ASSIGN(physicalDevice, CreatePhysicalDeviceFromIDXGIAdapter(dxgiAdapter));
-    mPhysicalDevices.emplace(desc.AdapterLuid, physicalDevice);
     return physicalDevice;
 }
 
+Ref<PhysicalDeviceBase> Backend::FindPhysicalDevice(const LUID& luid) {
+    auto it = mPhysicalDevices.find(luid);
+    if (it == mPhysicalDevices.end()) {
+        return nullptr;
+    }
+    // If we've already discovered this physical device, try to Promote the WeakRef to a Ref. If the
+    // WeakRef has been invalidated, nullptr is returned.
+    return it->second.Promote();
+}
+
 std::vector<Ref<PhysicalDeviceBase>> Backend::DiscoverPhysicalDevices(
     const UnpackedPtr<RequestAdapterOptions>& options) {
     if (options->forceFallbackAdapter) {
diff --git a/src/dawn/native/d3d/BackendD3D.h b/src/dawn/native/d3d/BackendD3D.h
index 7ae6fdb..295b699 100644
--- a/src/dawn/native/d3d/BackendD3D.h
+++ b/src/dawn/native/d3d/BackendD3D.h
@@ -102,6 +102,7 @@
     ResultOrError<Ref<PhysicalDeviceBase>> GetOrCreatePhysicalDeviceFromLUID(LUID luid);
     ResultOrError<Ref<PhysicalDeviceBase>> GetOrCreatePhysicalDeviceFromIDXGIAdapter(
         ComPtr<IDXGIAdapter> dxgiAdapter);
+    Ref<PhysicalDeviceBase> FindPhysicalDevice(const LUID& luid);
 
     // Acquiring DXC version information and store the result in mDxcVersionInfo. This function
     // should be called only once, during startup in `Initialize`.
@@ -133,9 +134,13 @@
     };
 
     // Map of LUID to physical device.
-    // The LUID is guaranteed to be uniquely identify an adapter on the local
-    // machine until restart.
-    absl::flat_hash_map<LUID, Ref<PhysicalDeviceBase>, LUIDHashFunc, LUIDEqualFunc>
+    // The LUID is guaranteed to uniquely identify an adapter on the local machine until restart.
+    // A WeakRef prevents the PhysicalDeviceBase (and its D3D Device) from being kept alive if there
+    // are no longer any external references. Any reference to a D3D Device keeps its corresponding
+    // physical adapter powered on. Since `DiscoverPhysicalDevices` may enumerate and add all
+    // available adapters, we should release the ones that the caller does not take a strong
+    // reference on. Otherwise, all adapters on the system will be kept powered on indefinitely.
+    absl::flat_hash_map<LUID, WeakRef<PhysicalDeviceBase>, LUIDHashFunc, LUIDEqualFunc>
         mPhysicalDevices;
 };