Make Adapter and Instance lifetimes more robust

Previously, dropping the instance before an adapter created from it
caused a use-after-free. This CL fixes up the lifetimes so that each
Device holds a ref to its Adapter, which in turn holds a ref to the
Instance. The Instance uses a cycle-breaking refcount so that it
releases its internal refs to its adapters when the last external ref
is dropped.

Bug: none
Change-Id: I5304ec86f425247d4c45ca342fda393cc19689e3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99820
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/include/dawn/native/DawnNative.h b/include/dawn/native/DawnNative.h
index 9dab397..0d0df14 100644
--- a/include/dawn/native/DawnNative.h
+++ b/include/dawn/native/DawnNative.h
@@ -190,6 +190,9 @@
 // Backdoor to get the number of deprecation warnings for testing
 DAWN_NATIVE_EXPORT size_t GetDeprecationWarningCountForTesting(WGPUDevice device);
 
+// Backdoor to get the number of adapters an instance knows about for testing
+DAWN_NATIVE_EXPORT size_t GetAdapterCountForTesting(WGPUInstance instance);
+
 //  Query if texture has been initialized
 DAWN_NATIVE_EXPORT bool IsTextureSubresourceInitialized(
     WGPUTexture texture,
diff --git a/src/dawn/native/Adapter.cpp b/src/dawn/native/Adapter.cpp
index 68e0420..7415312 100644
--- a/src/dawn/native/Adapter.cpp
+++ b/src/dawn/native/Adapter.cpp
@@ -31,6 +31,8 @@
     mSupportedFeatures.EnableFeature(Feature::DawnInternalUsages);
 }
 
+AdapterBase::~AdapterBase() = default;
+
 MaybeError AdapterBase::Initialize() {
     DAWN_TRY_CONTEXT(InitializeImpl(), "initializing adapter (backend=%s)", mBackend);
     InitializeVendorArchitectureImpl();
@@ -157,7 +159,7 @@
 }
 
 InstanceBase* AdapterBase::GetInstance() const {
-    return mInstance;
+    return mInstance.Get();
 }
 
 FeaturesSet AdapterBase::GetSupportedFeatures() const {
diff --git a/src/dawn/native/Adapter.h b/src/dawn/native/Adapter.h
index f979b35..6b6448f 100644
--- a/src/dawn/native/Adapter.h
+++ b/src/dawn/native/Adapter.h
@@ -33,7 +33,7 @@
 class AdapterBase : public RefCounted {
   public:
     AdapterBase(InstanceBase* instance, wgpu::BackendType backend);
-    ~AdapterBase() override = default;
+    ~AdapterBase() override;
 
     MaybeError Initialize();
 
@@ -90,7 +90,7 @@
     ResultOrError<Ref<DeviceBase>> CreateDeviceInternal(const DeviceDescriptor* descriptor);
 
     virtual MaybeError ResetInternalDeviceForTestingImpl();
-    InstanceBase* mInstance = nullptr;
+    Ref<InstanceBase> mInstance;
     wgpu::BackendType mBackend;
     CombinedLimits mLimits;
     bool mUseTieredLimits = false;
diff --git a/src/dawn/native/DawnNative.cpp b/src/dawn/native/DawnNative.cpp
index 5cb93e2..89c48ac 100644
--- a/src/dawn/native/DawnNative.cpp
+++ b/src/dawn/native/DawnNative.cpp
@@ -191,7 +191,7 @@
 
 Instance::~Instance() {
     if (mImpl != nullptr) {
-        mImpl->Release();
+        mImpl->APIRelease();
         mImpl = nullptr;
     }
 }
@@ -256,6 +256,10 @@
     return FromAPI(device)->GetDeprecationWarningCountForTesting();
 }
 
+size_t GetAdapterCountForTesting(WGPUInstance instance) {
+    return FromAPI(instance)->GetAdapters().size();
+}
+
 bool IsTextureSubresourceInitialized(WGPUTexture texture,
                                      uint32_t baseMipLevel,
                                      uint32_t levelCount,
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index f737e47..a83d185 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -171,8 +171,8 @@
 // DeviceBase
 
 DeviceBase::DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor)
-    : mInstance(adapter->GetInstance()), mAdapter(adapter), mNextPipelineCompatibilityToken(1) {
-    mInstance->IncrementDeviceCountForTesting();
+    : mAdapter(adapter), mNextPipelineCompatibilityToken(1) {
+    mAdapter->GetInstance()->IncrementDeviceCountForTesting();
     ASSERT(descriptor != nullptr);
 
     AdapterProperties adapterProperties;
@@ -221,9 +221,9 @@
     // We need to explicitly release the Queue before we complete the destructor so that the
     // Queue does not get destroyed after the Device.
     mQueue = nullptr;
-    // mInstance is not set for mock test devices.
-    if (mInstance != nullptr) {
-        mInstance->DecrementDeviceCountForTesting();
+    // mAdapter is not set for mock test devices.
+    if (mAdapter != nullptr) {
+        mAdapter->GetInstance()->DecrementDeviceCountForTesting();
     }
 }
 
@@ -628,7 +628,7 @@
     // generate cache keys. We can lift the dependency once we also cache frontend parsing,
     // transformations, and reflection.
     if (IsToggleEnabled(Toggle::EnableBlobCache)) {
-        return mInstance->GetBlobCache();
+        return mAdapter->GetInstance()->GetBlobCache();
     }
 #endif
     return nullptr;
@@ -696,7 +696,7 @@
 }
 
 AdapterBase* DeviceBase::GetAdapter() const {
-    return mAdapter;
+    return mAdapter.Get();
 }
 
 dawn::platform::Platform* DeviceBase::GetPlatform() const {
@@ -1286,7 +1286,7 @@
 
 AdapterBase* DeviceBase::APIGetAdapter() {
     mAdapter->Reference();
-    return mAdapter;
+    return mAdapter.Get();
 }
 
 QueueBase* DeviceBase::APIGetQueue() {
diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h
index b8b145b..9e4fe03 100644
--- a/src/dawn/native/Device.h
+++ b/src/dawn/native/Device.h
@@ -527,12 +527,7 @@
 
     std::unique_ptr<ErrorScopeStack> mErrorScopeStack;
 
-    // The Device keeps a ref to the Instance so that any live Device keeps the Instance alive.
-    // The Instance shouldn't need to ref child objects so this shouldn't introduce ref cycles.
-    // The Device keeps a simple pointer to the Adapter because the Adapter is owned by the
-    // Instance.
-    Ref<InstanceBase> mInstance;
-    AdapterBase* mAdapter = nullptr;
+    Ref<AdapterBase> mAdapter;
 
     // The object caches aren't exposed in the header as they would require a lot of
     // additional includes.
diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp
index afe9c1d..a582199 100644
--- a/src/dawn/native/Instance.cpp
+++ b/src/dawn/native/Instance.cpp
@@ -125,6 +125,17 @@
 
 InstanceBase::~InstanceBase() = default;
 
+void InstanceBase::WillDropLastExternalRef() {
+    // InstanceBase uses RefCountedWithExternalCount to break refcycles.
+    //
+    // InstanceBase holds Refs to AdapterBases it has discovered, which hold Refs back to the
+    // InstanceBase.
+    // In order to break this cycle and prevent leaks, when the application drops the last external
+    // ref and WillDropLastExternalRef is called, the instance clears out any member refs to
+    // adapters that hold back-refs to the instance - thus breaking any reference cycles.
+    mAdapters.clear();
+}
+
 // TODO(crbug.com/dawn/832): make the platform an initialization parameter of the instance.
 MaybeError InstanceBase::Initialize(const InstanceDescriptor* descriptor) {
     DAWN_TRY(ValidateSingleSType(descriptor->nextInChain, wgpu::SType::DawnInstanceDescriptor));
diff --git a/src/dawn/native/Instance.h b/src/dawn/native/Instance.h
index df22f27..0589061 100644
--- a/src/dawn/native/Instance.h
+++ b/src/dawn/native/Instance.h
@@ -27,6 +27,7 @@
 #include "dawn/native/BackendConnection.h"
 #include "dawn/native/BlobCache.h"
 #include "dawn/native/Features.h"
+#include "dawn/native/RefCountedWithExternalCount.h"
 #include "dawn/native/Toggles.h"
 #include "dawn/native/dawn_platform.h"
 
@@ -45,7 +46,7 @@
 
 // This is called InstanceBase for consistency across the frontend, even if the backends don't
 // specialize this class.
-class InstanceBase final : public RefCounted {
+class InstanceBase final : public RefCountedWithExternalCount {
   public:
     static Ref<InstanceBase> Create(const InstanceDescriptor* descriptor = nullptr);
 
@@ -110,6 +111,8 @@
     InstanceBase();
     ~InstanceBase() override;
 
+    void WillDropLastExternalRef() override;
+
     InstanceBase(const InstanceBase& other) = delete;
     InstanceBase& operator=(const InstanceBase& other) = delete;
 
diff --git a/src/dawn/tests/end2end/DeviceInitializationTests.cpp b/src/dawn/tests/end2end/DeviceInitializationTests.cpp
index 0c7e621..023b7ac 100644
--- a/src/dawn/tests/end2end/DeviceInitializationTests.cpp
+++ b/src/dawn/tests/end2end/DeviceInitializationTests.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <memory>
+#include <utility>
 #include <vector>
 
 #include "dawn/dawn_proc.h"
@@ -21,9 +22,48 @@
 #include "dawn/utils/WGPUHelpers.h"
 
 class DeviceInitializationTest : public testing::Test {
+  protected:
     void SetUp() override { dawnProcSetProcs(&dawn::native::GetProcs()); }
 
     void TearDown() override { dawnProcSetProcs(nullptr); }
+
+    // Test that the device can still be used by testing a buffer copy.
+    void ExpectDeviceUsable(wgpu::Device device) {
+        wgpu::Buffer src =
+            utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::CopySrc, {1, 2, 3, 4});
+
+        wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>(
+            device, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, {0, 0, 0, 0});
+
+        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+        encoder.CopyBufferToBuffer(src, 0, dst, 0, 4 * sizeof(uint32_t));
+
+        wgpu::CommandBuffer commands = encoder.Finish();
+        device.GetQueue().Submit(1, &commands);
+
+        bool done = false;
+        dst.MapAsync(
+            wgpu::MapMode::Read, 0, 4 * sizeof(uint32_t),
+            [](WGPUBufferMapAsyncStatus status, void* userdata) {
+                EXPECT_EQ(status, WGPUBufferMapAsyncStatus_Success);
+                *static_cast<bool*>(userdata) = true;
+            },
+            &done);
+
+        // Note: we can't actually test this if Tick moves over to
+        // wgpuInstanceProcessEvents. We can still test that object creation works
+        // without crashing.
+        while (!done) {
+            device.Tick();
+            utils::USleep(100);
+        }
+
+        const uint32_t* mapping = static_cast<const uint32_t*>(dst.GetConstMappedRange());
+        EXPECT_EQ(mapping[0], 1u);
+        EXPECT_EQ(mapping[1], 2u);
+        EXPECT_EQ(mapping[2], 3u);
+        EXPECT_EQ(mapping[3], 4u);
+    }
 };
 
 // Test that device operations are still valid if the reference to the instance
@@ -66,40 +106,64 @@
             }
         }
 
-        // Now, test that the device can still be used by testing a buffer copy.
-        wgpu::Buffer src =
-            utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::CopySrc, {1, 2, 3, 4});
+        if (device) {
+            ExpectDeviceUsable(std::move(device));
+        }
+    }
+}
 
-        wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>(
-            device, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, {0, 0, 0, 0});
+// Test that it is still possible to create a device from an adapter after the reference to the
+// instance is dropped.
+TEST_F(DeviceInitializationTest, AdapterOutlivesInstance) {
+    // Get properties of all available adapters and then free the instance.
+    // We want to create a device on a fresh instance and adapter each time.
+    std::vector<wgpu::AdapterProperties> availableAdapterProperties;
+    {
+        auto instance = std::make_unique<dawn::native::Instance>();
+        instance->DiscoverDefaultAdapters();
+        for (const dawn::native::Adapter& adapter : instance->GetAdapters()) {
+            wgpu::AdapterProperties properties;
+            adapter.GetProperties(&properties);
 
-        wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-        encoder.CopyBufferToBuffer(src, 0, dst, 0, 4 * sizeof(uint32_t));
+            if (properties.backendType == wgpu::BackendType::Null) {
+                continue;
+            }
+            availableAdapterProperties.push_back(properties);
+        }
+    }
 
-        wgpu::CommandBuffer commands = encoder.Finish();
-        device.GetQueue().Submit(1, &commands);
+    for (const wgpu::AdapterProperties& desiredProperties : availableAdapterProperties) {
+        wgpu::Adapter adapter;
 
-        bool done = false;
-        dst.MapAsync(
-            wgpu::MapMode::Read, 0, 4 * sizeof(uint32_t),
-            [](WGPUBufferMapAsyncStatus status, void* userdata) {
-                EXPECT_EQ(status, WGPUBufferMapAsyncStatus_Success);
-                *static_cast<bool*>(userdata) = true;
-            },
-            &done);
+        auto instance = std::make_unique<dawn::native::Instance>();
+        // Save a pointer to the instance.
+        // It will only be valid as long as the instance is alive.
+        WGPUInstance unsafeInstancePtr = instance->Get();
 
-        // Note: we can't actually test this if Tick moves over to
-        // wgpuInstanceProcessEvents. We can still test that object creation works
-        // without crashing.
-        while (!done) {
-            device.Tick();
-            utils::USleep(100);
+        instance->DiscoverDefaultAdapters();
+        for (dawn::native::Adapter& nativeAdapter : instance->GetAdapters()) {
+            wgpu::AdapterProperties properties;
+            nativeAdapter.GetProperties(&properties);
+
+            if (properties.deviceID == desiredProperties.deviceID &&
+                properties.vendorID == desiredProperties.vendorID &&
+                properties.adapterType == desiredProperties.adapterType &&
+                properties.backendType == desiredProperties.backendType) {
+                // Save the adapter, and reset the instance.
+                // Check that the number of adapters before the reset is > 0, and after the reset
+                // is 0. Unsafe, but we assume the pointer is still valid since the adapter is
+                // holding onto the instance. The instance should have cleared all internal
+                // references to adapters when the last external ref is dropped.
+                adapter = wgpu::Adapter(nativeAdapter.Get());
+                EXPECT_GT(dawn::native::GetAdapterCountForTesting(unsafeInstancePtr), 0u);
+                instance.reset();
+                EXPECT_EQ(dawn::native::GetAdapterCountForTesting(unsafeInstancePtr), 0u);
+                break;
+            }
         }
 
-        const uint32_t* mapping = static_cast<const uint32_t*>(dst.GetConstMappedRange());
-        EXPECT_EQ(mapping[0], 1u);
-        EXPECT_EQ(mapping[1], 2u);
-        EXPECT_EQ(mapping[2], 3u);
-        EXPECT_EQ(mapping[3], 4u);
+        if (adapter) {
+            ExpectDeviceUsable(adapter.CreateDevice());
+        }
     }
 }
diff --git a/src/dawn/tests/unittests/ToBackendTests.cpp b/src/dawn/tests/unittests/ToBackendTests.cpp
index 91c4b66..8720a27 100644
--- a/src/dawn/tests/unittests/ToBackendTests.cpp
+++ b/src/dawn/tests/unittests/ToBackendTests.cpp
@@ -19,17 +19,23 @@
 #include "dawn/common/RefCounted.h"
 #include "dawn/native/ToBackend.h"
 
-// Make our own Base - Backend object pair, reusing the AdapterBase name
+// Make our own Base - Backend object pair (MyObjectBase / MyObject) to exercise ToBackend
 namespace dawn::native {
-class AdapterBase : public RefCounted {};
 
-class MyAdapter : public AdapterBase {};
+class MyObjectBase : public RefCounted {};
+
+class MyObject : public MyObjectBase {};
 
 struct MyBackendTraits {
-    using AdapterType = MyAdapter;
+    using MyObjectType = MyObject;
 };
 
-// Instanciate ToBackend for our "backend"
+template <typename BackendTraits>
+struct ToBackendTraits<MyObjectBase, BackendTraits> {
+    using BackendType = typename BackendTraits::MyObjectType;
+};
+
+// Instantiate ToBackend for our "backend"
 template <typename T>
 auto ToBackend(T&& common) -> decltype(ToBackendBase<MyBackendTraits>(common)) {
     return ToBackendBase<MyBackendTraits>(common);
@@ -38,49 +44,48 @@
 // Test that ToBackend correctly converts pointers to base classes.
 TEST(ToBackend, Pointers) {
     {
-        MyAdapter* adapter = new MyAdapter;
-        const AdapterBase* base = adapter;
+        MyObject* myObject = new MyObject;
+        const MyObjectBase* base = myObject;
 
         auto* backendAdapter = ToBackend(base);
-        static_assert(std::is_same<decltype(backendAdapter), const MyAdapter*>::value);
-        ASSERT_EQ(adapter, backendAdapter);
+        static_assert(std::is_same<decltype(backendAdapter), const MyObject*>::value);
+        ASSERT_EQ(myObject, backendAdapter);
 
-        adapter->Release();
+        myObject->Release();
     }
     {
-        MyAdapter* adapter = new MyAdapter;
-        AdapterBase* base = adapter;
+        MyObject* myObject = new MyObject;
+        MyObjectBase* base = myObject;
 
         auto* backendAdapter = ToBackend(base);
-        static_assert(std::is_same<decltype(backendAdapter), MyAdapter*>::value);
-        ASSERT_EQ(adapter, backendAdapter);
+        static_assert(std::is_same<decltype(backendAdapter), MyObject*>::value);
+        ASSERT_EQ(myObject, backendAdapter);
 
-        adapter->Release();
+        myObject->Release();
     }
 }
 
 // Test that ToBackend correctly converts Refs to base classes.
 TEST(ToBackend, Ref) {
     {
-        MyAdapter* adapter = new MyAdapter;
-        const Ref<AdapterBase> base(adapter);
+        MyObject* myObject = new MyObject;
+        const Ref<MyObjectBase> base(myObject);
 
         const auto& backendAdapter = ToBackend(base);
-        static_assert(std::is_same<decltype(ToBackend(base)), const Ref<MyAdapter>&>::value);
-        ASSERT_EQ(adapter, backendAdapter.Get());
+        static_assert(std::is_same<decltype(ToBackend(base)), const Ref<MyObject>&>::value);
+        ASSERT_EQ(myObject, backendAdapter.Get());
 
-        adapter->Release();
+        myObject->Release();
     }
     {
-        MyAdapter* adapter = new MyAdapter;
-        Ref<AdapterBase> base(adapter);
+        MyObject* myObject = new MyObject;
+        Ref<MyObjectBase> base(myObject);
 
         auto backendAdapter = ToBackend(base);
-        static_assert(std::is_same<decltype(ToBackend(base)), Ref<MyAdapter>&>::value);
-        ASSERT_EQ(adapter, backendAdapter.Get());
+        static_assert(std::is_same<decltype(ToBackend(base)), Ref<MyObject>&>::value);
+        ASSERT_EQ(myObject, backendAdapter.Get());
 
-        adapter->Release();
+        myObject->Release();
     }
 }
-
 }  // namespace dawn::native