Clear all callbacks when the Instance is dropped.
- Fixes fuzzer issue in the bug. The userdata was not cleaned up
because an async function (MapAsync) was called on a destroyed
device, and the associated callback was never called in the
fuzzer(s).
- Modernizes the code a bit to use MutexProtected.
- Updates MutexProtected to work with const-ness.
Bug: chromium:1497701
Change-Id: I15617bd00e49e6c2d1c5c2805e4335438b9ce0e2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/158824
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/common/MutexProtected.h b/src/dawn/common/MutexProtected.h
index 6f97d80..051822e 100644
--- a/src/dawn/common/MutexProtected.h
+++ b/src/dawn/common/MutexProtected.h
@@ -79,7 +79,8 @@
const ReturnType& operator*() const { return *Traits::GetObj(mObj); }
private:
- friend class MutexProtected<T>;
+ using NonConstT = typename std::remove_const<T>::type;
+ friend class MutexProtected<NonConstT>;
Guard(T* obj, typename Traits::MutexType& mutex) : mLock(Traits::GetMutex(mutex)), mObj(obj) {}
@@ -142,12 +143,16 @@
auto Use(Fn&& fn) {
return fn(Use());
}
+ template <typename Fn>
+ auto Use(Fn&& fn) const {
+ return fn(Use());
+ }
private:
Usage Use() { return Usage(&mObj, mMutex); }
ConstUsage Use() const { return ConstUsage(&mObj, mMutex); }
- typename Traits::MutexType mMutex;
+ mutable typename Traits::MutexType mMutex;
T mObj;
};
diff --git a/src/dawn/native/CallbackTaskManager.cpp b/src/dawn/native/CallbackTaskManager.cpp
index 1d6bde9..366e3bd 100644
--- a/src/dawn/native/CallbackTaskManager.cpp
+++ b/src/dawn/native/CallbackTaskManager.cpp
@@ -49,10 +49,10 @@
void CallbackTask::Execute() {
switch (mState) {
- case State::HandleDeviceLoss:
+ case CallbackState::DeviceLoss:
HandleDeviceLossImpl();
break;
- case State::HandleShutDown:
+ case CallbackState::ShutDown:
HandleShutDownImpl();
break;
default:
@@ -62,17 +62,17 @@
void CallbackTask::OnShutDown() {
// Only first state change will have effects in final Execute().
- if (mState != State::Normal) {
+ if (mState != CallbackState::Normal) {
return;
}
- mState = State::HandleShutDown;
+ mState = CallbackState::ShutDown;
}
void CallbackTask::OnDeviceLoss() {
- if (mState != State::Normal) {
+ if (mState != CallbackState::Normal) {
return;
}
- mState = State::HandleDeviceLoss;
+ mState = CallbackState::DeviceLoss;
}
CallbackTaskManager::CallbackTaskManager() = default;
@@ -80,13 +80,23 @@
CallbackTaskManager::~CallbackTaskManager() = default;
bool CallbackTaskManager::IsEmpty() {
- std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
- return mCallbackTaskQueue.empty();
+ return mStateAndQueue.Use([](auto stateAndQueue) { return stateAndQueue->mTaskQueue.empty(); });
}
void CallbackTaskManager::AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask) {
- std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
- mCallbackTaskQueue.push_back(std::move(callbackTask));
+ mStateAndQueue.Use([&](auto stateAndQueue) {
+ switch (stateAndQueue->mState) {
+ case CallbackState::ShutDown:
+ callbackTask->OnShutDown();
+ break;
+ case CallbackState::DeviceLoss:
+ callbackTask->OnDeviceLoss();
+ break;
+ default:
+ break;
+ }
+ stateAndQueue->mTaskQueue.push_back(std::move(callbackTask));
+ });
}
void CallbackTaskManager::AddCallbackTask(std::function<void()> callback) {
@@ -94,22 +104,31 @@
}
void CallbackTaskManager::HandleDeviceLoss() {
- std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
- for (auto& task : mCallbackTaskQueue) {
- task->OnDeviceLoss();
- }
+ mStateAndQueue.Use([&](auto stateAndQueue) {
+ if (stateAndQueue->mState != CallbackState::Normal) {
+ return;
+ }
+ stateAndQueue->mState = CallbackState::DeviceLoss;
+ for (auto& task : stateAndQueue->mTaskQueue) {
+ task->OnDeviceLoss();
+ }
+ });
}
void CallbackTaskManager::HandleShutDown() {
- std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
- for (auto& task : mCallbackTaskQueue) {
- task->OnShutDown();
- }
+ mStateAndQueue.Use([&](auto stateAndQueue) {
+ if (stateAndQueue->mState != CallbackState::Normal) {
+ return;
+ }
+ stateAndQueue->mState = CallbackState::ShutDown;
+ for (auto& task : stateAndQueue->mTaskQueue) {
+ task->OnShutDown();
+ }
+ });
}
void CallbackTaskManager::Flush() {
- std::unique_lock<std::mutex> lock(mCallbackTaskQueueMutex);
- if (mCallbackTaskQueue.empty()) {
+ if (IsEmpty()) {
return;
}
@@ -118,10 +137,9 @@
// such reentrant call, we remove all the callback tasks from mCallbackTaskManager,
// update mCallbackTaskManager, then call all the callbacks.
std::vector<std::unique_ptr<CallbackTask>> allTasks;
- allTasks.swap(mCallbackTaskQueue);
- lock.unlock();
+ mStateAndQueue.Use([&](auto stateAndQueue) { allTasks.swap(stateAndQueue->mTaskQueue); });
- for (std::unique_ptr<CallbackTask>& callbackTask : allTasks) {
+ for (auto& callbackTask : allTasks) {
callbackTask->Execute();
}
}
diff --git a/src/dawn/native/CallbackTaskManager.h b/src/dawn/native/CallbackTaskManager.h
index 2a4fbb6..55bffe3 100644
--- a/src/dawn/native/CallbackTaskManager.h
+++ b/src/dawn/native/CallbackTaskManager.h
@@ -33,11 +33,18 @@
#include <mutex>
#include <vector>
+#include "dawn/common/MutexProtected.h"
#include "dawn/common/RefCounted.h"
#include "dawn/common/TypeTraits.h"
namespace dawn::native {
+enum class CallbackState {
+ Normal,
+ ShutDown,
+ DeviceLoss,
+};
+
struct CallbackTask {
public:
virtual ~CallbackTask() = default;
@@ -52,13 +59,7 @@
virtual void HandleDeviceLossImpl() = 0;
private:
- enum class State {
- Normal,
- HandleShutDown,
- HandleDeviceLoss,
- };
-
- State mState = State::Normal;
+ CallbackState mState = CallbackState::Normal;
};
class CallbackTaskManager : public RefCounted {
@@ -80,8 +81,11 @@
void Flush();
private:
- std::mutex mCallbackTaskQueueMutex;
- std::vector<std::unique_ptr<CallbackTask>> mCallbackTaskQueue;
+ struct StateAndQueue {
+ CallbackState mState = CallbackState::Normal;
+ std::vector<std::unique_ptr<CallbackTask>> mTaskQueue;
+ };
+ MutexProtected<StateAndQueue> mStateAndQueue;
};
} // namespace dawn::native
diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp
index 504f3b3..927eac3 100644
--- a/src/dawn/native/Device.cpp
+++ b/src/dawn/native/Device.cpp
@@ -374,8 +374,8 @@
// references, they can no longer get the queue from APIGetQueue().
mQueue = nullptr;
- // Reset callbacks since after this, since after dropping the last external reference, the
- // application may have freed any device-scope memory needed to run the callback.
+ // Reset callbacks since after dropping the last external reference, the application may have
+ // freed any device-scope memory needed to run the callback.
mUncapturedErrorCallback = [](WGPUErrorType, char const* message, void*) {
dawn::WarningLog() << "Uncaptured error after last external device reference dropped.\n"
<< message;
@@ -450,7 +450,7 @@
// inside which the application may destroy the device. Thus, we should be careful not
// to delete objects that are needed inside Tick after callbacks have been called.
// - mCallbackTaskManager is not deleted since we flush the callback queue at the end
- // of Tick(). Note: that flush should always be emtpy since all callbacks are drained
+ // of Tick(). Note: that flush should always be empty since all callbacks are drained
// inside Destroy() so there should be no outstanding tasks holding objects alive.
// - Similiarly, mAsyncTaskManager is not deleted since we use it to return a status
// from Tick() whether or not there is any more pending work.
diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp
index 3a26bac..d058ff5 100644
--- a/src/dawn/native/Instance.cpp
+++ b/src/dawn/native/Instance.cpp
@@ -152,6 +152,28 @@
InstanceBase::~InstanceBase() = default;
+void InstanceBase::DeleteThis() {
+ // Flush all remaining callback tasks on all devices and on the instance.
+ std::set<DeviceBase*> devices;
+ do {
+ devices.clear();
+ mDevicesList.Use([&](auto deviceList) { devices.swap(*deviceList); });
+ for (auto device : devices) {
+ device->GetCallbackTaskManager()->HandleShutDown();
+ do {
+ device->GetCallbackTaskManager()->Flush();
+ } while (!device->GetCallbackTaskManager()->IsEmpty());
+ }
+ } while (!devices.empty());
+
+ mCallbackTaskManager->HandleShutDown();
+ do {
+ mCallbackTaskManager->Flush();
+ } while (!mCallbackTaskManager->IsEmpty());
+
+ RefCountedWithExternalCount::DeleteThis();
+}
+
void InstanceBase::WillDropLastExternalRef() {
// InstanceBase uses RefCountedWithExternalCount to break refcycles.
@@ -451,28 +473,24 @@
}
uint64_t InstanceBase::GetDeviceCountForTesting() const {
- std::lock_guard<std::mutex> lg(mDevicesListMutex);
- return mDevicesList.size();
+ return mDevicesList.Use([](auto deviceList) { return deviceList->size(); });
}
void InstanceBase::AddDevice(DeviceBase* device) {
- std::lock_guard<std::mutex> lg(mDevicesListMutex);
- mDevicesList.insert(device);
+ mDevicesList.Use([&](auto deviceList) { deviceList->insert(device); });
}
void InstanceBase::RemoveDevice(DeviceBase* device) {
- std::lock_guard<std::mutex> lg(mDevicesListMutex);
- mDevicesList.erase(device);
+ mDevicesList.Use([&](auto deviceList) { deviceList->erase(device); });
}
void InstanceBase::APIProcessEvents() {
std::vector<Ref<DeviceBase>> devices;
- {
- std::lock_guard<std::mutex> lg(mDevicesListMutex);
- for (auto device : mDevicesList) {
+ mDevicesList.Use([&](auto deviceList) {
+ for (auto device : *deviceList) {
devices.push_back(device);
}
- }
+ });
for (auto device : devices) {
device->APITick();
diff --git a/src/dawn/native/Instance.h b/src/dawn/native/Instance.h
index 44dc007..56b3e60 100644
--- a/src/dawn/native/Instance.h
+++ b/src/dawn/native/Instance.h
@@ -36,6 +36,7 @@
#include <unordered_set>
#include <vector>
+#include "dawn/common/MutexProtected.h"
#include "dawn/common/Ref.h"
#include "dawn/common/ityp_array.h"
#include "dawn/common/ityp_bitset.h"
@@ -164,6 +165,7 @@
explicit InstanceBase(const TogglesState& instanceToggles);
~InstanceBase() override;
+ void DeleteThis() override;
void WillDropLastExternalRef() override;
InstanceBase(const InstanceBase& other) = delete;
@@ -215,8 +217,7 @@
Ref<CallbackTaskManager> mCallbackTaskManager;
EventManager mEventManager;
- std::set<DeviceBase*> mDevicesList;
- mutable std::mutex mDevicesListMutex;
+ MutexProtected<std::set<DeviceBase*>> mDevicesList;
};
} // namespace dawn::native
diff --git a/src/dawn/native/RefCountedWithExternalCount.h b/src/dawn/native/RefCountedWithExternalCount.h
index 345b209..d27e260 100644
--- a/src/dawn/native/RefCountedWithExternalCount.h
+++ b/src/dawn/native/RefCountedWithExternalCount.h
@@ -47,6 +47,9 @@
void APIReference();
void APIRelease();
+ protected:
+ using RefCounted::DeleteThis;
+
private:
virtual void WillDropLastExternalRef() = 0;