Refactor AsyncTask into an object with helpers.
AsyncTaskManager::PostTask now constructs a task object to wrap the
running task and returns it. Allow for overriding the AsyncTask class
with a template on PostTask.
Update std::lock_guard to the more modern std::scoped_lock in
AsyncTaskManager.
Bug: 406520956
Change-Id: Ie66e5b704db3772117885ce71ea0308b13c093b7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/251114
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Geoff Lang <geofflang@chromium.org>
diff --git a/src/dawn/native/AsyncTask.cpp b/src/dawn/native/AsyncTask.cpp
index a9fb8ef..c756d82 100644
--- a/src/dawn/native/AsyncTask.cpp
+++ b/src/dawn/native/AsyncTask.cpp
@@ -33,62 +33,115 @@
namespace dawn::native {
+AsyncTask::AsyncTask(std::function<void()> task) : mTask(task) {}
+
+void AsyncTask::Wait() {
+ std::unique_ptr<dawn::platform::WaitableEvent> waitableEvent;
+ {
+ std::scoped_lock<std::mutex> lock(mMutex);
+ waitableEvent = std::move(mWaitableEvent);
+ }
+
+ if (waitableEvent) {
+ waitableEvent->Wait();
+ }
+}
+
+void AsyncTask::AddCompletionCallback(AsyncTaskCompletionCallback completionCallback) {
+ std::scoped_lock<std::mutex> lock(mMutex);
+
+ // If this task has already completed, call the completion callback immediately.
+ if (mState == AsyncTaskState::Completed) {
+ completionCallback();
+ return;
+ }
+
+ mCompletionCallbacks.push_back(completionCallback);
+}
+
+void AsyncTask::Run() {
+ {
+ AsyncTaskState prevState = mState.exchange(AsyncTaskState::Running);
+ DAWN_ASSERT(prevState == AsyncTaskState::Pending);
+ }
+
+ mTask();
+
+ // AsyncTask may have a much longer life time than the task itself.
+ // Reset it to release any references that were captured.
+ mTask = nullptr;
+
+ // Grab the completion callbacks while locked but call them outside the lock.
+ std::vector<AsyncTaskCompletionCallback> completionCallbacks;
+ {
+ std::scoped_lock<std::mutex> lock(mMutex);
+ AsyncTaskState prevState = mState.exchange(AsyncTaskState::Completed);
+ DAWN_ASSERT(prevState == AsyncTaskState::Running);
+ completionCallbacks = std::move(mCompletionCallbacks);
+ mCompletionCallbacks.clear();
+ mWaitableEvent = nullptr;
+ }
+
+ for (auto completionCallback : completionCallbacks) {
+ completionCallback();
+ }
+}
+
AsyncTaskManager::AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool)
: mWorkerTaskPool(workerTaskPool) {}
-void AsyncTaskManager::PostTask(AsyncTask asyncTask) {
- // If these allocations becomes expensive, we can slab-allocate tasks.
- Ref<WaitableTask> waitableTask = AcquireRef(new WaitableTask());
- waitableTask->taskManager = this;
- waitableTask->asyncTask = std::move(asyncTask);
+AsyncTaskManager::~AsyncTaskManager() {
+ // Pending tasks call back into this task manager. Make sure they all finish before destructing.
+ WaitAllPendingTasks();
+}
- {
- // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()),
- // and we may remove waitableTask objects from mPendingTasks in either main thread
- // (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be
- // protected by a mutex.
- std::lock_guard<std::mutex> lock(mPendingTasksMutex);
- mPendingTasks.emplace(waitableTask.Get(), waitableTask);
- }
+void AsyncTaskManager::PostConstructedTask(Ref<AsyncTask> asyncTask) {
+ // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()),
+ // and we may remove waitableTask objects from mPendingTasks in either main thread
+ // (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be
+ // protected by a mutex.
+ // Hold the mutex until the task is fully posted otherwise it could complete and be deleted
+ // from mPending tasks before it is fully initialized.
+ mPendingTasks.Use(
+ [&asyncTask, taskManager = this, taskPool = mWorkerTaskPool](auto pendingTasks) {
+ // If these allocations becomes expensive, we can slab-allocate tasks.
+ auto iter = pendingTasks->emplace(std::make_unique<WaitableTask>());
- // Ref the task since it is accessed inside the worker function.
- // The worker function will acquire and release the task upon completion.
- waitableTask->AddRef();
- waitableTask->waitableEvent =
- mWorkerTaskPool->PostWorkerTask(DoWaitableTask, waitableTask.Get());
+ // Should never be inserting the same value twice.
+ DAWN_ASSERT(iter.second);
+
+ WaitableTask* waitableTask = iter.first->get();
+ waitableTask->taskManager = taskManager;
+ waitableTask->asyncTask = asyncTask;
+
+ asyncTask->mWaitableEvent = taskPool->PostWorkerTask(RunTask, waitableTask);
+ });
}
void AsyncTaskManager::HandleTaskCompletion(WaitableTask* task) {
- std::lock_guard<std::mutex> lock(mPendingTasksMutex);
- mPendingTasks.erase(task);
+ DAWN_ASSERT(task);
+ DAWN_ASSERT(task->asyncTask->GetState() == AsyncTaskState::Completed);
+
+ mPendingTasks.Use([&task](auto pendingTasks) { return pendingTasks->erase(task); });
}
void AsyncTaskManager::WaitAllPendingTasks() {
- absl::flat_hash_map<WaitableTask*, Ref<WaitableTask>> allPendingTasks;
+ PendingTasksSet allPendingTasks =
+ mPendingTasks.Use([](auto pendingTasks) { return std::move(*pendingTasks); });
- {
- std::lock_guard<std::mutex> lock(mPendingTasksMutex);
- allPendingTasks.swap(mPendingTasks);
- }
-
- for (auto& [_, task] : allPendingTasks) {
- task->waitableEvent->Wait();
+ for (auto& task : allPendingTasks) {
+ task->asyncTask->Wait();
}
}
bool AsyncTaskManager::HasPendingTasks() {
- std::lock_guard<std::mutex> lock(mPendingTasksMutex);
- return !mPendingTasks.empty();
+ return mPendingTasks.Use([](auto pendingTasks) { return !pendingTasks->empty(); });
}
-void AsyncTaskManager::DoWaitableTask(void* task) {
- Ref<WaitableTask> waitableTask = AcquireRef(static_cast<WaitableTask*>(task));
- waitableTask->asyncTask();
- waitableTask->taskManager->HandleTaskCompletion(waitableTask.Get());
+void AsyncTaskManager::RunTask(void* task) {
+ WaitableTask* waitableTask = static_cast<WaitableTask*>(task);
+ waitableTask->asyncTask->Run();
+ waitableTask->taskManager->HandleTaskCompletion(waitableTask);
}
-AsyncTaskManager::WaitableTask::WaitableTask() = default;
-
-AsyncTaskManager::WaitableTask::~WaitableTask() = default;
-
} // namespace dawn::native
diff --git a/src/dawn/native/AsyncTask.h b/src/dawn/native/AsyncTask.h
index 6457b92..ef310e2 100644
--- a/src/dawn/native/AsyncTask.h
+++ b/src/dawn/native/AsyncTask.h
@@ -31,8 +31,12 @@
#include <functional>
#include <memory>
#include <mutex>
+#include <utility>
+#include <vector>
-#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "dawn/common/MutexProtected.h"
+#include "dawn/common/NonCopyable.h"
#include "dawn/common/Ref.h"
#include "dawn/common/RefCounted.h"
#include "partition_alloc/pointers/raw_ptr.h"
@@ -49,32 +53,74 @@
// shutting down the device. RunNow() could be used for more advanced scenarios, for example
// always doing ShaderModule initial compilation asynchronously, but being able to steal the
// task if we need it for synchronous pipeline compilation.
-using AsyncTask = std::function<void()>;
+
+enum class AsyncTaskState : uint8_t {
+ Pending = 0,
+ Running = 1,
+ Completed = 2,
+};
+
+using AsyncTaskFunction = std::function<void()>;
+using AsyncTaskCompletionCallback = std::function<void()>;
+
+class AsyncTask : public RefCounted {
+ public:
+ explicit AsyncTask(AsyncTaskFunction task);
+
+ AsyncTaskState GetState() const { return mState; }
+
+ void Wait();
+
+ void AddCompletionCallback(AsyncTaskCompletionCallback completionCallback);
+
+ private:
+ // Friends with the task manager to privately manage when tasks are executed.
+ friend class AsyncTaskManager;
+ void Run();
+
+ AsyncTaskFunction mTask;
+
+ // Use a mutex to guard changes to mCompletionCallbacks, mWaitableEvent and transitioning mState
+ // to Completed.
+ // mState is atomic for a lockless getter and Pending -> Running transition.
+ std::mutex mMutex;
+ std::atomic<AsyncTaskState> mState = AsyncTaskState::Pending;
+ std::vector<AsyncTaskCompletionCallback> mCompletionCallbacks;
+
+ // Hold onto the waitable platform event until the task has completed. Released before the
+ // destruction of the AsyncTask to be as light weight as possible.
+ std::unique_ptr<dawn::platform::WaitableEvent> mWaitableEvent;
+};
class AsyncTaskManager {
public:
explicit AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool);
+ ~AsyncTaskManager();
- void PostTask(AsyncTask asyncTask);
+ template <typename TaskType, class... Args>
+ Ref<TaskType> PostTask(Args&&... args) {
+ Ref<TaskType> asyncTask = AcquireRef(new TaskType(std::forward<Args>(args)...));
+ PostConstructedTask(asyncTask);
+ return asyncTask;
+ }
+
void WaitAllPendingTasks();
bool HasPendingTasks();
private:
- class WaitableTask : public RefCounted {
- public:
- WaitableTask();
- ~WaitableTask() override;
-
- AsyncTask asyncTask;
+ struct WaitableTask : NonCopyable {
+ Ref<AsyncTask> asyncTask;
raw_ptr<AsyncTaskManager> taskManager;
- std::unique_ptr<dawn::platform::WaitableEvent> waitableEvent;
};
- static void DoWaitableTask(void* task);
+ void PostConstructedTask(Ref<AsyncTask> asyncTask);
+
+ static void RunTask(void* task);
void HandleTaskCompletion(WaitableTask* task);
- std::mutex mPendingTasksMutex;
- absl::flat_hash_map<WaitableTask*, Ref<WaitableTask>> mPendingTasks;
+ using PendingTasksSet = absl::flat_hash_set<std::unique_ptr<WaitableTask>>;
+ MutexProtected<PendingTasksSet> mPendingTasks;
+
raw_ptr<dawn::platform::WorkerTaskPool> mWorkerTaskPool;
};
diff --git a/src/dawn/native/CreatePipelineAsyncEvent.cpp b/src/dawn/native/CreatePipelineAsyncEvent.cpp
index 5f8521b..6eb5ce0 100644
--- a/src/dawn/native/CreatePipelineAsyncEvent.cpp
+++ b/src/dawn/native/CreatePipelineAsyncEvent.cpp
@@ -176,7 +176,7 @@
"CreatePipelineAsyncEvent::InitializeAsync", this, "label", eventLabel);
auto asyncTask = [event = Ref<CreatePipelineAsyncEvent>(this)] { event->InitializeImpl(true); };
- device->GetAsyncTaskManager()->PostTask(std::move(asyncTask));
+ device->GetAsyncTaskManager()->PostTask<AsyncTask>(std::move(asyncTask));
}
template <typename PipelineType, typename CreatePipelineAsyncCallbackInfo>
diff --git a/src/dawn/tests/unittests/AsyncTaskTests.cpp b/src/dawn/tests/unittests/AsyncTaskTests.cpp
index 6db4c33..65a1ace 100644
--- a/src/dawn/tests/unittests/AsyncTaskTests.cpp
+++ b/src/dawn/tests/unittests/AsyncTaskTests.cpp
@@ -40,7 +40,7 @@
#include "dawn/platform/DawnPlatform.h"
#include "gtest/gtest.h"
-namespace dawn {
+namespace dawn::native {
namespace {
struct SimpleTaskResult {
@@ -82,14 +82,13 @@
platform::Platform platform;
std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
- native::AsyncTaskManager taskManager(pool.get());
+ AsyncTaskManager taskManager(pool.get());
ConcurrentTaskResultQueue taskResultQueue;
constexpr size_t kTaskCount = 4u;
std::set<uint32_t> idset;
for (uint32_t i = 0; i < kTaskCount; ++i) {
- native::AsyncTask asyncTask([&taskResultQueue, i] { DoTask(&taskResultQueue, i); });
- taskManager.PostTask(std::move(asyncTask));
+ taskManager.PostTask<AsyncTask>([&taskResultQueue, i] { DoTask(&taskResultQueue, i); });
idset.insert(i);
}
@@ -103,5 +102,65 @@
ASSERT_TRUE(idset.empty());
}
+// Test that the task status is updated based on the task's running state
+TEST_F(AsyncTaskTest, Status) {
+ platform::Platform platform;
+ std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
+
+ AsyncTaskManager taskManager(pool.get());
+ ConcurrentTaskResultQueue taskResultQueue;
+
+ // Use a mutex to force the task to wait on the main thread before completing
+ std::mutex mutex;
+ std::unique_lock lock(mutex);
+
+ auto task = taskManager.PostTask<AsyncTask>(
+ [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); });
+
+ ASSERT_NE(task->GetState(), AsyncTaskState::Completed);
+
+ // Allow the task to complete
+ lock.unlock();
+ task->Wait();
+ ASSERT_EQ(task->GetState(), AsyncTaskState::Completed);
+}
+
+// Test coverage of the completion callbacks for tasks
+TEST_F(AsyncTaskTest, Callbacks) {
+ platform::Platform platform;
+ std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
+
+ AsyncTaskManager taskManager(pool.get());
+ ConcurrentTaskResultQueue taskResultQueue;
+
+ // Use a mutex to force the task to wait on the main thread before completing
+ std::mutex mutex;
+ std::unique_lock lock(mutex);
+ auto waitingTaskFunction = taskManager.PostTask<AsyncTask>(
+ [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); });
+
+ // Use a completion callback that simply counts how many times it's been called
+ std::atomic<uint64_t> completionCallbackCounter = 0;
+ AsyncTaskCompletionCallback completionCallback = [&completionCallbackCounter]() {
+ completionCallbackCounter++;
+ };
+
+ // Spawn a task that waits for the mutex. Add a completion callback and confirm that it's not
+ // called before the task completes
+ auto task = taskManager.PostTask<AsyncTask>(
+ [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); });
+ task->AddCompletionCallback(completionCallback);
+ EXPECT_EQ(completionCallbackCounter, 0u);
+
+ // Allow the task to complete and expect that the completion callback has been called
+ lock.unlock();
+ task->Wait();
+ EXPECT_EQ(completionCallbackCounter, 1u);
+
+ // Add another completion callback to the already-completed task and check that it is called
+ // immediately
+ task->AddCompletionCallback(completionCallback);
+ EXPECT_EQ(completionCallbackCounter, 2u);
+}
} // anonymous namespace
-} // namespace dawn
+} // namespace dawn::native
diff --git a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
index 225bfed..7211ac3 100644
--- a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
+++ b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
@@ -49,12 +49,12 @@
std::atomic_bool done(false);
// Simulate that an async task would take a long time to finish.
- dawn::native::AsyncTask asyncTask([&done] {
+ AsyncTaskFunction task([&done] {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
done = true;
});
- mDeviceMock->GetAsyncTaskManager()->PostTask(std::move(asyncTask));
+ mDeviceMock->GetAsyncTaskManager()->PostTask<AsyncTask>(std::move(task));
DropDevice();
// Dropping the device should force the async task to finish.
EXPECT_TRUE(done.load());