Clear all callbacks when the Instance is dropped.

- Fixes fuzzer issue in the bug. The userdata was not cleaned up
  because an async function (MapAsync) was called on a destroyed
  device, and the associated callback was never called in the
  fuzzer(s).
- Modernizes the code a bit to use MutexProtected.
- Updates MutexProtected to work with const-ness.

Bug: chromium:1497701
Change-Id: I15617bd00e49e6c2d1c5c2805e4335438b9ce0e2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/158824
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index 6f97d80..051822e 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -79,7 +79,8 @@
     const ReturnType& operator*() const { return *Traits::GetObj(mObj); }
 
   private:
-    friend class MutexProtected<T>;
+    using NonConstT = typename std::remove_const<T>::type;
+    friend class MutexProtected<NonConstT>;
 
     Guard(T* obj, typename Traits::MutexType& mutex) : mLock(Traits::GetMutex(mutex)), mObj(obj) {}
 
@@ -142,12 +143,16 @@
     auto Use(Fn&& fn) {
         return fn(Use());
     }
+    template <typename Fn>
+    auto Use(Fn&& fn) const {
+        return fn(Use());
+    }
 
   private:
     Usage Use() { return Usage(&mObj, mMutex); }
     ConstUsage Use() const { return ConstUsage(&mObj, mMutex); }
 
-    typename Traits::MutexType mMutex;
+    mutable typename Traits::MutexType mMutex;
     T mObj;
 };
 
diff --git a/src/dawn/native/CallbackTaskManager.cpp b/src/dawn/native/CallbackTaskManager.cpp
index 1d6bde9..366e3bd 100644
--- a/src/dawn/native/CallbackTaskManager.cpp
+++ b/src/dawn/native/CallbackTaskManager.cpp
@@ -49,10 +49,10 @@
 
 void CallbackTask::Execute() {
     switch (mState) {
-        case State::HandleDeviceLoss:
+        case CallbackState::DeviceLoss:
             HandleDeviceLossImpl();
             break;
-        case State::HandleShutDown:
+        case CallbackState::ShutDown:
             HandleShutDownImpl();
             break;
         default:
@@ -62,17 +62,17 @@
 
 void CallbackTask::OnShutDown() {
     // Only first state change will have effects in final Execute().
-    if (mState != State::Normal) {
+    if (mState != CallbackState::Normal) {
         return;
     }
-    mState = State::HandleShutDown;
+    mState = CallbackState::ShutDown;
 }
 
 void CallbackTask::OnDeviceLoss() {
-    if (mState != State::Normal) {
+    if (mState != CallbackState::Normal) {
         return;
     }
-    mState = State::HandleDeviceLoss;
+    mState = CallbackState::DeviceLoss;
 }
 
 CallbackTaskManager::CallbackTaskManager() = default;
@@ -80,13 +80,23 @@
 CallbackTaskManager::~CallbackTaskManager() = default;
 
 bool CallbackTaskManager::IsEmpty() {
-    std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
-    return mCallbackTaskQueue.empty();
+    return mStateAndQueue.Use([](auto stateAndQueue) { return stateAndQueue->mTaskQueue.empty(); });
 }
 
 void CallbackTaskManager::AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask) {
-    std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
-    mCallbackTaskQueue.push_back(std::move(callbackTask));
+    mStateAndQueue.Use([&](auto stateAndQueue) {
+        switch (stateAndQueue->mState) {
+            case CallbackState::ShutDown:
+                callbackTask->OnShutDown();
+                break;
+            case CallbackState::DeviceLoss:
+                callbackTask->OnDeviceLoss();
+                break;
+            default:
+                break;
+        }
+        stateAndQueue->mTaskQueue.push_back(std::move(callbackTask));
+    });
 }
 
 void CallbackTaskManager::AddCallbackTask(std::function<void()> callback) {
@@ -94,22 +104,31 @@
 }
 
 void CallbackTaskManager::HandleDeviceLoss() {
-    std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
-    for (auto& task : mCallbackTaskQueue) {
-        task->OnDeviceLoss();
-    }
+    mStateAndQueue.Use([&](auto stateAndQueue) {
+        if (stateAndQueue->mState != CallbackState::Normal) {
+            return;
+        }
+        stateAndQueue->mState = CallbackState::DeviceLoss;
+        for (auto& task : stateAndQueue->mTaskQueue) {
+            task->OnDeviceLoss();
+        }
+    });
 }
 
 void CallbackTaskManager::HandleShutDown() {
-    std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
-    for (auto& task : mCallbackTaskQueue) {
-        task->OnShutDown();
-    }
+    mStateAndQueue.Use([&](auto stateAndQueue) {
+        if (stateAndQueue->mState != CallbackState::Normal) {
+            return;
+        }
+        stateAndQueue->mState = CallbackState::ShutDown;
+        for (auto& task : stateAndQueue->mTaskQueue) {
+            task->OnShutDown();
+        }
+    });
 }
 
 void CallbackTaskManager::Flush() {
-    std::unique_lock<std::mutex> lock(mCallbackTaskQueueMutex);
-    if (mCallbackTaskQueue.empty()) {
+    if (IsEmpty()) {
         return;
     }
 
@@ -118,10 +137,9 @@
     // such reentrant call, we remove all the callback tasks from mCallbackTaskManager,
     // update mCallbackTaskManager, then call all the callbacks.
     std::vector<std::unique_ptr<CallbackTask>> allTasks;
-    allTasks.swap(mCallbackTaskQueue);
-    lock.unlock();
+    mStateAndQueue.Use([&](auto stateAndQueue) { allTasks.swap(stateAndQueue->mTaskQueue); });
 
-    for (std::unique_ptr<CallbackTask>& callbackTask : allTasks) {
+    for (auto& callbackTask : allTasks) {
         callbackTask->Execute();
     }
 }
diff --git a/src/dawn/native/CallbackTaskManager.h b/src/dawn/native/CallbackTaskManager.h
index 2a4fbb6..55bffe3 100644
--- a/src/dawn/native/CallbackTaskManager.h
+++ b/src/dawn/native/CallbackTaskManager.h
@@ -33,11 +33,18 @@
 #include <mutex>
 #include <vector>
 
+#include "dawn/common/MutexProtected.h"
 #include "dawn/common/RefCounted.h"
 #include "dawn/common/TypeTraits.h"
 
 namespace dawn::native {
 
+enum class CallbackState {
+    Normal,
+    ShutDown,
+    DeviceLoss,
+};
+
 struct CallbackTask {
   public:
     virtual ~CallbackTask() = default;
@@ -52,13 +59,7 @@
     virtual void HandleDeviceLossImpl() = 0;
 
   private:
-    enum class State {
-        Normal,
-        HandleShutDown,
-        HandleDeviceLoss,
-    };
-
-    State mState = State::Normal;
+    CallbackState mState = CallbackState::Normal;
 };
 
 class CallbackTaskManager : public RefCounted {
@@ -80,8 +81,11 @@
     void Flush();
 
   private:
-    std::mutex mCallbackTaskQueueMutex;
-    std::vector<std::unique_ptr<CallbackTask>> mCallbackTaskQueue;
+    struct StateAndQueue {
+        CallbackState mState = CallbackState::Normal;
+        std::vector<std::unique_ptr<CallbackTask>> mTaskQueue;
+    };
+    MutexProtected<StateAndQueue> mStateAndQueue;
 };
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 504f3b3..927eac3 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -374,8 +374,8 @@
     // references, they can no longer get the queue from APIGetQueue().
     mQueue = nullptr;
 
-    // Reset callbacks since after this, since after dropping the last external reference, the
-    // application may have freed any device-scope memory needed to run the callback.
+    // Reset callbacks since after dropping the last external reference, the application may have
+    // freed any device-scope memory needed to run the callback.
     mUncapturedErrorCallback = [](WGPUErrorType, char const* message, void*) {
         dawn::WarningLog() << "Uncaptured error after last external device reference dropped.\n"
                            << message;
@@ -450,7 +450,7 @@
     // inside which the application may destroy the device. Thus, we should be careful not
     // to delete objects that are needed inside Tick after callbacks have been called.
     //  - mCallbackTaskManager is not deleted since we flush the callback queue at the end
-    // of Tick(). Note: that flush should always be emtpy since all callbacks are drained
+    // of Tick(). Note: that flush should always be empty since all callbacks are drained
     // inside Destroy() so there should be no outstanding tasks holding objects alive.
     //  - Similiarly, mAsyncTaskManager is not deleted since we use it to return a status
     // from Tick() whether or not there is any more pending work.
diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp
index 3a26bac..d058ff5 100644
--- a/src/dawn/native/Instance.cpp
+++ b/src/dawn/native/Instance.cpp
@@ -152,6 +152,28 @@
 
 InstanceBase::~InstanceBase() = default;
 
+void InstanceBase::DeleteThis() {
+    // Flush all remaining callback tasks on all devices and on the instance.
+    std::set<DeviceBase*> devices;
+    do {
+        devices.clear();
+        mDevicesList.Use([&](auto deviceList) { devices.swap(*deviceList); });
+        for (auto device : devices) {
+            device->GetCallbackTaskManager()->HandleShutDown();
+            do {
+                device->GetCallbackTaskManager()->Flush();
+            } while (!device->GetCallbackTaskManager()->IsEmpty());
+        }
+    } while (!devices.empty());
+
+    mCallbackTaskManager->HandleShutDown();
+    do {
+        mCallbackTaskManager->Flush();
+    } while (!mCallbackTaskManager->IsEmpty());
+
+    RefCountedWithExternalCount::DeleteThis();
+}
+
 void InstanceBase::WillDropLastExternalRef() {
     // InstanceBase uses RefCountedWithExternalCount to break refcycles.
 
@@ -451,28 +473,24 @@
 }
 
 uint64_t InstanceBase::GetDeviceCountForTesting() const {
-    std::lock_guard<std::mutex> lg(mDevicesListMutex);
-    return mDevicesList.size();
+    return mDevicesList.Use([](auto deviceList) { return deviceList->size(); });
 }
 
 void InstanceBase::AddDevice(DeviceBase* device) {
-    std::lock_guard<std::mutex> lg(mDevicesListMutex);
-    mDevicesList.insert(device);
+    mDevicesList.Use([&](auto deviceList) { deviceList->insert(device); });
 }
 
 void InstanceBase::RemoveDevice(DeviceBase* device) {
-    std::lock_guard<std::mutex> lg(mDevicesListMutex);
-    mDevicesList.erase(device);
+    mDevicesList.Use([&](auto deviceList) { deviceList->erase(device); });
 }
 
 void InstanceBase::APIProcessEvents() {
     std::vector<Ref<DeviceBase>> devices;
-    {
-        std::lock_guard<std::mutex> lg(mDevicesListMutex);
-        for (auto device : mDevicesList) {
+    mDevicesList.Use([&](auto deviceList) {
+        for (auto device : *deviceList) {
             devices.push_back(device);
         }
-    }
+    });
 
     for (auto device : devices) {
         device->APITick();
diff --git a/src/dawn/native/Instance.h b/src/dawn/native/Instance.h
index 44dc007..56b3e60 100644
--- a/src/dawn/native/Instance.h
+++ b/src/dawn/native/Instance.h
@@ -36,6 +36,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "dawn/common/MutexProtected.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/ityp_array.h"
 #include "dawn/common/ityp_bitset.h"
@@ -164,6 +165,7 @@
     explicit InstanceBase(const TogglesState& instanceToggles);
     ~InstanceBase() override;
 
+    void DeleteThis() override;
     void WillDropLastExternalRef() override;
 
     InstanceBase(const InstanceBase& other) = delete;
@@ -215,8 +217,7 @@
     Ref<CallbackTaskManager> mCallbackTaskManager;
     EventManager mEventManager;
 
-    std::set<DeviceBase*> mDevicesList;
-    mutable std::mutex mDevicesListMutex;
+    MutexProtected<std::set<DeviceBase*>> mDevicesList;
 };
 
 }  // namespace dawn::native
diff --git a/src/dawn/native/RefCountedWithExternalCount.h b/src/dawn/native/RefCountedWithExternalCount.h
index 345b209..d27e260 100644
--- a/src/dawn/native/RefCountedWithExternalCount.h
+++ b/src/dawn/native/RefCountedWithExternalCount.h
@@ -47,6 +47,9 @@
     void APIReference();
     void APIRelease();
 
+  protected:
+    using RefCounted::DeleteThis;
+
   private:
     virtual void WillDropLastExternalRef() = 0;