Reland "Refactor AsyncTask into an object with helpers."

Added locking of the AsyncTask's mutex while the task is being posted
and its waitable event is being stored. The task may run and complete
before PostWorkerTask returns, so without the lock the worker thread
could modify the waitable event while it is still being written.

This is a reland of commit 773146ed837bac5363753246e4e2ea3f476f20b1

Original change's description:
> Refactor AsyncTask into an object with helpers.
>
> AsyncTaskManager::PostTask now constructs a task object to wrap the
> running task and returns it. Allow for overriding the AsyncTask class
> with a template on PostTask.
>
> Update std::lock_guard to the more modern std::scoped_lock in
> AsyncTaskManager.
>
> Bug: 406520956
>
> Change-Id: Ie66e5b704db3772117885ce71ea0308b13c093b7
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/251114
> Reviewed-by: Corentin Wallez <cwallez@chromium.org>
> Reviewed-by: Loko Kung <lokokung@google.com>
> Commit-Queue: Geoff Lang <geofflang@chromium.org>

Bug: 406520956
Change-Id: I0d2b14bd05a5d4d88a5aec20ac3ec5c4bb2e0555
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/263714
Commit-Queue: Geoff Lang <geofflang@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
diff --git a/src/dawn/native/AsyncTask.cpp b/src/dawn/native/AsyncTask.cpp index a9fb8ef..2821bb3 100644 --- a/src/dawn/native/AsyncTask.cpp +++ b/src/dawn/native/AsyncTask.cpp
@@ -33,62 +33,119 @@ namespace dawn::native { +AsyncTask::AsyncTask(std::function<void()> task) : mTask(task) {} + +void AsyncTask::Wait() { + std::unique_ptr<dawn::platform::WaitableEvent> waitableEvent; + { + std::scoped_lock<std::mutex> lock(mMutex); + waitableEvent = std::move(mWaitableEvent); + } + + if (waitableEvent) { + waitableEvent->Wait(); + } +} + +void AsyncTask::AddCompletionCallback(AsyncTaskCompletionCallback completionCallback) { + std::scoped_lock<std::mutex> lock(mMutex); + + // If this task has already completed, call the completion callback immediately. + if (mState == AsyncTaskState::Completed) { + completionCallback(); + return; + } + + mCompletionCallbacks.push_back(completionCallback); +} + +void AsyncTask::Run() { + { + AsyncTaskState prevState = mState.exchange(AsyncTaskState::Running); + DAWN_ASSERT(prevState == AsyncTaskState::Pending); + } + + mTask(); + + // AsyncTask may have a much longer life time than the task itself. + // Reset it to release any references that were captured. + mTask = nullptr; + + // Grab the completion callbacks while locked but call them outside the lock. + std::vector<AsyncTaskCompletionCallback> completionCallbacks; + { + std::scoped_lock<std::mutex> lock(mMutex); + AsyncTaskState prevState = mState.exchange(AsyncTaskState::Completed); + DAWN_ASSERT(prevState == AsyncTaskState::Running); + completionCallbacks = std::move(mCompletionCallbacks); + mCompletionCallbacks.clear(); + mWaitableEvent = nullptr; + } + + for (auto completionCallback : completionCallbacks) { + completionCallback(); + } +} + AsyncTaskManager::AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool) : mWorkerTaskPool(workerTaskPool) {} -void AsyncTaskManager::PostTask(AsyncTask asyncTask) { - // If these allocations becomes expensive, we can slab-allocate tasks. 
- Ref<WaitableTask> waitableTask = AcquireRef(new WaitableTask()); - waitableTask->taskManager = this; - waitableTask->asyncTask = std::move(asyncTask); +AsyncTaskManager::~AsyncTaskManager() { + // Pending tasks call back into this task manager. Make sure they all finish before destructing. + WaitAllPendingTasks(); +} - { - // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()), - // and we may remove waitableTask objects from mPendingTasks in either main thread - // (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be - // protected by a mutex. - std::lock_guard<std::mutex> lock(mPendingTasksMutex); - mPendingTasks.emplace(waitableTask.Get(), waitableTask); - } +void AsyncTaskManager::PostConstructedTask(Ref<AsyncTask> asyncTask) { + // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()), + // and we may remove waitableTask objects from mPendingTasks in either main thread + // (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be + // protected by a mutex. + // Hold the mutex until the task is fully posted; otherwise it could complete and be deleted + // from mPendingTasks before it is fully initialized. + mPendingTasks.Use( + [&asyncTask, taskManager = this, taskPool = mWorkerTaskPool](auto pendingTasks) { + // If these allocations become expensive, we can slab-allocate tasks. + auto iter = pendingTasks->emplace(std::make_unique<WaitableTask>()); - // Ref the task since it is accessed inside the worker function. - // The worker function will acquire and release the task upon completion. - waitableTask->AddRef(); - waitableTask->waitableEvent = - mWorkerTaskPool->PostWorkerTask(DoWaitableTask, waitableTask.Get()); + // Should never be inserting the same value twice. 
+ DAWN_ASSERT(iter.second); + + WaitableTask* waitableTask = iter.first->get(); + waitableTask->taskManager = taskManager; + waitableTask->asyncTask = asyncTask; + + // Hold the task's mutex while writing to mWaitableEvent. The task could run and try to + // modify the waitable event while this write is happening. + std::scoped_lock<std::mutex> lock(asyncTask->mMutex); + asyncTask->mWaitableEvent = taskPool->PostWorkerTask(RunTask, waitableTask); + }); } void AsyncTaskManager::HandleTaskCompletion(WaitableTask* task) { - std::lock_guard<std::mutex> lock(mPendingTasksMutex); - mPendingTasks.erase(task); + DAWN_ASSERT(task); + DAWN_ASSERT(task->asyncTask->GetState() == AsyncTaskState::Completed); + + mPendingTasks.Use([&task](auto pendingTasks) { return pendingTasks->erase(task); }); } void AsyncTaskManager::WaitAllPendingTasks() { - absl::flat_hash_map<WaitableTask*, Ref<WaitableTask>> allPendingTasks; + PendingTasksSet allPendingTasks; + mPendingTasks.Use( + [&allPendingTasks](auto pendingTasks) { allPendingTasks.swap(*pendingTasks); }); - { - std::lock_guard<std::mutex> lock(mPendingTasksMutex); - allPendingTasks.swap(mPendingTasks); - } - - for (auto& [_, task] : allPendingTasks) { - task->waitableEvent->Wait(); + for (auto& task : allPendingTasks) { + task->asyncTask->Wait(); } } bool AsyncTaskManager::HasPendingTasks() { - std::lock_guard<std::mutex> lock(mPendingTasksMutex); - return !mPendingTasks.empty(); + return mPendingTasks.Use([](auto pendingTasks) { return !pendingTasks->empty(); }); } -void AsyncTaskManager::DoWaitableTask(void* task) { - Ref<WaitableTask> waitableTask = AcquireRef(static_cast<WaitableTask*>(task)); - waitableTask->asyncTask(); - waitableTask->taskManager->HandleTaskCompletion(waitableTask.Get()); +void AsyncTaskManager::RunTask(void* task) { + WaitableTask* waitableTask = static_cast<WaitableTask*>(task); + waitableTask->asyncTask->Run(); + waitableTask->taskManager->HandleTaskCompletion(waitableTask); } 
-AsyncTaskManager::WaitableTask::WaitableTask() = default; - -AsyncTaskManager::WaitableTask::~WaitableTask() = default; - } // namespace dawn::native
diff --git a/src/dawn/native/AsyncTask.h b/src/dawn/native/AsyncTask.h index 6457b92..ef310e2 100644 --- a/src/dawn/native/AsyncTask.h +++ b/src/dawn/native/AsyncTask.h
@@ -31,8 +31,12 @@ #include <functional> #include <memory> #include <mutex> +#include <utility> +#include <vector> -#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "dawn/common/MutexProtected.h" +#include "dawn/common/NonCopyable.h" #include "dawn/common/Ref.h" #include "dawn/common/RefCounted.h" #include "partition_alloc/pointers/raw_ptr.h" @@ -49,32 +53,74 @@ // shutting down the device. RunNow() could be used for more advanced scenarios, for example // always doing ShaderModule initial compilation asynchronously, but being able to steal the // task if we need it for synchronous pipeline compilation. -using AsyncTask = std::function<void()>; + +enum class AsyncTaskState : uint8_t { + Pending = 0, + Running = 1, + Completed = 2, +}; + +using AsyncTaskFunction = std::function<void()>; +using AsyncTaskCompletionCallback = std::function<void()>; + +class AsyncTask : public RefCounted { + public: + explicit AsyncTask(AsyncTaskFunction task); + + AsyncTaskState GetState() const { return mState; } + + void Wait(); + + void AddCompletionCallback(AsyncTaskCompletionCallback completionCallback); + + private: + // Friends with the task manager to privately manage when tasks are executed. + friend class AsyncTaskManager; + void Run(); + + AsyncTaskFunction mTask; + + // Use a mutex to guard changes to mCompletionCallbacks, mWaitableEvent and transitioning mState + // to Completed. + // mState is atomic for a lockless getter and Pending -> Running transition. + std::mutex mMutex; + std::atomic<AsyncTaskState> mState = AsyncTaskState::Pending; + std::vector<AsyncTaskCompletionCallback> mCompletionCallbacks; + + // Hold onto the waitable platform event until the task has completed. Released before the + // destruction of the AsyncTask to be as light weight as possible. 
+ std::unique_ptr<dawn::platform::WaitableEvent> mWaitableEvent; +}; class AsyncTaskManager { public: explicit AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool); + ~AsyncTaskManager(); - void PostTask(AsyncTask asyncTask); + template <typename TaskType, class... Args> + Ref<TaskType> PostTask(Args&&... args) { + Ref<TaskType> asyncTask = AcquireRef(new TaskType(std::forward<Args>(args)...)); + PostConstructedTask(asyncTask); + return asyncTask; + } + void WaitAllPendingTasks(); bool HasPendingTasks(); private: - class WaitableTask : public RefCounted { - public: - WaitableTask(); - ~WaitableTask() override; - - AsyncTask asyncTask; + struct WaitableTask : NonCopyable { + Ref<AsyncTask> asyncTask; raw_ptr<AsyncTaskManager> taskManager; - std::unique_ptr<dawn::platform::WaitableEvent> waitableEvent; }; - static void DoWaitableTask(void* task); + void PostConstructedTask(Ref<AsyncTask> asyncTask); + + static void RunTask(void* task); void HandleTaskCompletion(WaitableTask* task); - std::mutex mPendingTasksMutex; - absl::flat_hash_map<WaitableTask*, Ref<WaitableTask>> mPendingTasks; + using PendingTasksSet = absl::flat_hash_set<std::unique_ptr<WaitableTask>>; + MutexProtected<PendingTasksSet> mPendingTasks; + raw_ptr<dawn::platform::WorkerTaskPool> mWorkerTaskPool; };
diff --git a/src/dawn/native/CreatePipelineAsyncEvent.cpp b/src/dawn/native/CreatePipelineAsyncEvent.cpp index ba506a1..a67fa64 100644 --- a/src/dawn/native/CreatePipelineAsyncEvent.cpp +++ b/src/dawn/native/CreatePipelineAsyncEvent.cpp
@@ -174,7 +174,7 @@ "CreatePipelineAsyncEvent::InitializeAsync", this, "label", eventLabel); auto asyncTask = [event = Ref<CreatePipelineAsyncEvent>(this)] { event->InitializeImpl(true); }; - device->GetAsyncTaskManager()->PostTask(std::move(asyncTask)); + device->GetAsyncTaskManager()->PostTask<AsyncTask>(std::move(asyncTask)); } template <typename PipelineType, typename CreatePipelineAsyncCallbackInfo>
diff --git a/src/dawn/tests/unittests/AsyncTaskTests.cpp b/src/dawn/tests/unittests/AsyncTaskTests.cpp index 6db4c33..65a1ace 100644 --- a/src/dawn/tests/unittests/AsyncTaskTests.cpp +++ b/src/dawn/tests/unittests/AsyncTaskTests.cpp
@@ -40,7 +40,7 @@ #include "dawn/platform/DawnPlatform.h" #include "gtest/gtest.h" -namespace dawn { +namespace dawn::native { namespace { struct SimpleTaskResult { @@ -82,14 +82,13 @@ platform::Platform platform; std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool(); - native::AsyncTaskManager taskManager(pool.get()); + AsyncTaskManager taskManager(pool.get()); ConcurrentTaskResultQueue taskResultQueue; constexpr size_t kTaskCount = 4u; std::set<uint32_t> idset; for (uint32_t i = 0; i < kTaskCount; ++i) { - native::AsyncTask asyncTask([&taskResultQueue, i] { DoTask(&taskResultQueue, i); }); - taskManager.PostTask(std::move(asyncTask)); + taskManager.PostTask<AsyncTask>([&taskResultQueue, i] { DoTask(&taskResultQueue, i); }); idset.insert(i); } @@ -103,5 +102,65 @@ ASSERT_TRUE(idset.empty()); } +// Test that the task status is updated based on the task's running state +TEST_F(AsyncTaskTest, Status) { + platform::Platform platform; + std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool(); + + AsyncTaskManager taskManager(pool.get()); + ConcurrentTaskResultQueue taskResultQueue; + + // Use a mutex to force the task to wait on the main thread before completing + std::mutex mutex; + std::unique_lock lock(mutex); + + auto task = taskManager.PostTask<AsyncTask>( + [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); }); + + ASSERT_NE(task->GetState(), AsyncTaskState::Completed); + + // Allow the task to complete + lock.unlock(); + task->Wait(); + ASSERT_EQ(task->GetState(), AsyncTaskState::Completed); +} + +// Test coverage of the completion callbacks for tasks +TEST_F(AsyncTaskTest, Callbacks) { + platform::Platform platform; + std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool(); + + AsyncTaskManager taskManager(pool.get()); + ConcurrentTaskResultQueue taskResultQueue; + + // Use a mutex to force the task to wait on the main thread before completing + std::mutex mutex; + 
std::unique_lock lock(mutex); + auto waitingTaskFunction = taskManager.PostTask<AsyncTask>( + [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); }); + + // Use a completion callback that simply counts how many times it's been called + std::atomic<uint64_t> completionCallbackCounter = 0; + AsyncTaskCompletionCallback completionCallback = [&completionCallbackCounter]() { + completionCallbackCounter++; + }; + + // Spawn a task that waits for the mutex. Add a completion callback and confirm that it's not + // called before the task completes + auto task = taskManager.PostTask<AsyncTask>( + [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); }); + task->AddCompletionCallback(completionCallback); + EXPECT_EQ(completionCallbackCounter, 0u); + + // Allow the task to complete and expect that the completion callback has been called + lock.unlock(); + task->Wait(); + EXPECT_EQ(completionCallbackCounter, 1u); + + // Add another completion callback to the already-completed task and check that it is called + // immediately + task->AddCompletionCallback(completionCallback); + EXPECT_EQ(completionCallbackCounter, 2u); +} } // anonymous namespace -} // namespace dawn +} // namespace dawn::native
diff --git a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp index 225bfed..7211ac3 100644 --- a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp +++ b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
@@ -49,12 +49,12 @@ std::atomic_bool done(false); // Simulate that an async task would take a long time to finish. - dawn::native::AsyncTask asyncTask([&done] { + AsyncTaskFunction task([&done] { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); done = true; }); - mDeviceMock->GetAsyncTaskManager()->PostTask(std::move(asyncTask)); + mDeviceMock->GetAsyncTaskManager()->PostTask<AsyncTask>(std::move(task)); DropDevice(); // Dropping the device should force the async task to finish. EXPECT_TRUE(done.load());