Refactor AsyncTask into an object with helpers.

AsyncTaskManager::PostTask now constructs a task object to wrap the
running task and returns it. PostTask is templated on the task type so
that callers can post subclasses of AsyncTask.

Update std::lock_guard to the more modern std::scoped_lock in
AsyncTaskManager.

Bug: 406520956

Change-Id: Ie66e5b704db3772117885ce71ea0308b13c093b7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/251114
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Geoff Lang <geofflang@chromium.org>
diff --git a/src/dawn/native/AsyncTask.cpp b/src/dawn/native/AsyncTask.cpp
index a9fb8ef..c756d82 100644
--- a/src/dawn/native/AsyncTask.cpp
+++ b/src/dawn/native/AsyncTask.cpp
@@ -33,62 +33,115 @@
 
 namespace dawn::native {
 
+AsyncTask::AsyncTask(std::function<void()> task) : mTask(task) {}
+
+void AsyncTask::Wait() {
+    std::unique_ptr<dawn::platform::WaitableEvent> waitableEvent;
+    {
+        std::scoped_lock<std::mutex> lock(mMutex);
+        waitableEvent = std::move(mWaitableEvent);
+    }
+
+    if (waitableEvent) {
+        waitableEvent->Wait();
+    }
+}
+
+void AsyncTask::AddCompletionCallback(AsyncTaskCompletionCallback completionCallback) {
+    std::scoped_lock<std::mutex> lock(mMutex);
+
+    // If this task has already completed, call the completion callback immediately.
+    if (mState == AsyncTaskState::Completed) {
+        completionCallback();
+        return;
+    }
+
+    mCompletionCallbacks.push_back(completionCallback);
+}
+
+void AsyncTask::Run() {
+    {
+        AsyncTaskState prevState = mState.exchange(AsyncTaskState::Running);
+        DAWN_ASSERT(prevState == AsyncTaskState::Pending);
+    }
+
+    mTask();
+
+    // AsyncTask may have a much longer lifetime than the task function itself.
+    // Reset mTask to release any references that it captured.
+    mTask = nullptr;
+
+    // Grab the completion callbacks while locked but call them outside the lock.
+    std::vector<AsyncTaskCompletionCallback> completionCallbacks;
+    {
+        std::scoped_lock<std::mutex> lock(mMutex);
+        AsyncTaskState prevState = mState.exchange(AsyncTaskState::Completed);
+        DAWN_ASSERT(prevState == AsyncTaskState::Running);
+        completionCallbacks = std::move(mCompletionCallbacks);
+        mCompletionCallbacks.clear();
+        mWaitableEvent = nullptr;
+    }
+
+    for (auto completionCallback : completionCallbacks) {
+        completionCallback();
+    }
+}
+
 AsyncTaskManager::AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool)
     : mWorkerTaskPool(workerTaskPool) {}
 
-void AsyncTaskManager::PostTask(AsyncTask asyncTask) {
-    // If these allocations becomes expensive, we can slab-allocate tasks.
-    Ref<WaitableTask> waitableTask = AcquireRef(new WaitableTask());
-    waitableTask->taskManager = this;
-    waitableTask->asyncTask = std::move(asyncTask);
+AsyncTaskManager::~AsyncTaskManager() {
+    // Pending tasks call back into this task manager. Make sure they all finish before destructing.
+    WaitAllPendingTasks();
+}
 
-    {
-        // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()),
-        // and we may remove waitableTask objects from mPendingTasks in either main thread
-        // (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be
-        // protected by a mutex.
-        std::lock_guard<std::mutex> lock(mPendingTasksMutex);
-        mPendingTasks.emplace(waitableTask.Get(), waitableTask);
-    }
+void AsyncTaskManager::PostConstructedTask(Ref<AsyncTask> asyncTask) {
+    // WaitableTask objects are inserted into mPendingTasks on the posting thread (here) and may be
+    // removed either by WaitAllPendingTasks() or on a worker thread (HandleTaskCompletion()), so
+    // mPendingTasks is protected by a mutex.
+    //
+    // Hold the mutex until the task is fully posted, otherwise it could complete and be deleted
+    // from mPendingTasks before it is fully initialized.
+    mPendingTasks.Use(
+        [&asyncTask, taskManager = this, taskPool = mWorkerTaskPool](auto pendingTasks) {
+            // If these allocations become expensive, we can slab-allocate tasks.
+            auto iter = pendingTasks->emplace(std::make_unique<WaitableTask>());
 
-    // Ref the task since it is accessed inside the worker function.
-    // The worker function will acquire and release the task upon completion.
-    waitableTask->AddRef();
-    waitableTask->waitableEvent =
-        mWorkerTaskPool->PostWorkerTask(DoWaitableTask, waitableTask.Get());
+            // Should never be inserting the same value twice.
+            DAWN_ASSERT(iter.second);
+
+            WaitableTask* waitableTask = iter.first->get();
+            waitableTask->taskManager = taskManager;
+            waitableTask->asyncTask = asyncTask;
+
+            asyncTask->mWaitableEvent = taskPool->PostWorkerTask(RunTask, waitableTask);
+        });
 }
 
 void AsyncTaskManager::HandleTaskCompletion(WaitableTask* task) {
-    std::lock_guard<std::mutex> lock(mPendingTasksMutex);
-    mPendingTasks.erase(task);
+    DAWN_ASSERT(task);
+    DAWN_ASSERT(task->asyncTask->GetState() == AsyncTaskState::Completed);
+
+    mPendingTasks.Use([&task](auto pendingTasks) { return pendingTasks->erase(task); });
 }
 
 void AsyncTaskManager::WaitAllPendingTasks() {
-    absl::flat_hash_map<WaitableTask*, Ref<WaitableTask>> allPendingTasks;
+    PendingTasksSet allPendingTasks =
+        mPendingTasks.Use([](auto pendingTasks) { return std::move(*pendingTasks); });
 
-    {
-        std::lock_guard<std::mutex> lock(mPendingTasksMutex);
-        allPendingTasks.swap(mPendingTasks);
-    }
-
-    for (auto& [_, task] : allPendingTasks) {
-        task->waitableEvent->Wait();
+    for (auto& task : allPendingTasks) {
+        task->asyncTask->Wait();
     }
 }
 
 bool AsyncTaskManager::HasPendingTasks() {
-    std::lock_guard<std::mutex> lock(mPendingTasksMutex);
-    return !mPendingTasks.empty();
+    return mPendingTasks.Use([](auto pendingTasks) { return !pendingTasks->empty(); });
 }
 
-void AsyncTaskManager::DoWaitableTask(void* task) {
-    Ref<WaitableTask> waitableTask = AcquireRef(static_cast<WaitableTask*>(task));
-    waitableTask->asyncTask();
-    waitableTask->taskManager->HandleTaskCompletion(waitableTask.Get());
+void AsyncTaskManager::RunTask(void* task) {
+    WaitableTask* waitableTask = static_cast<WaitableTask*>(task);
+    waitableTask->asyncTask->Run();
+    waitableTask->taskManager->HandleTaskCompletion(waitableTask);
 }
 
-AsyncTaskManager::WaitableTask::WaitableTask() = default;
-
-AsyncTaskManager::WaitableTask::~WaitableTask() = default;
-
 }  // namespace dawn::native
diff --git a/src/dawn/native/AsyncTask.h b/src/dawn/native/AsyncTask.h
index 6457b92..ef310e2 100644
--- a/src/dawn/native/AsyncTask.h
+++ b/src/dawn/native/AsyncTask.h
@@ -31,8 +31,12 @@
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <utility>
+#include <vector>
 
-#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "dawn/common/MutexProtected.h"
+#include "dawn/common/NonCopyable.h"
 #include "dawn/common/Ref.h"
 #include "dawn/common/RefCounted.h"
 #include "partition_alloc/pointers/raw_ptr.h"
@@ -49,32 +53,74 @@
 // shutting down the device. RunNow() could be used for more advanced scenarios, for example
 // always doing ShaderModule initial compilation asynchronously, but being able to steal the
 // task if we need it for synchronous pipeline compilation.
-using AsyncTask = std::function<void()>;
+
+enum class AsyncTaskState : uint8_t {
+    Pending = 0,
+    Running = 1,
+    Completed = 2,
+};
+
+using AsyncTaskFunction = std::function<void()>;
+using AsyncTaskCompletionCallback = std::function<void()>;
+
+class AsyncTask : public RefCounted {
+  public:
+    explicit AsyncTask(AsyncTaskFunction task);
+
+    AsyncTaskState GetState() const { return mState; }
+
+    void Wait();
+
+    void AddCompletionCallback(AsyncTaskCompletionCallback completionCallback);
+
+  private:
+    // Friends with the task manager to privately manage when tasks are executed.
+    friend class AsyncTaskManager;
+    void Run();
+
+    AsyncTaskFunction mTask;
+
+    // Use a mutex to guard changes to mCompletionCallbacks, mWaitableEvent and transitioning mState
+    // to Completed.
+    // mState is atomic for a lockless getter and Pending -> Running transition.
+    std::mutex mMutex;
+    std::atomic<AsyncTaskState> mState = AsyncTaskState::Pending;
+    std::vector<AsyncTaskCompletionCallback> mCompletionCallbacks;
+
+    // Hold onto the waitable platform event until the task has completed. Released before the
+    // destruction of the AsyncTask to be as lightweight as possible.
+    std::unique_ptr<dawn::platform::WaitableEvent> mWaitableEvent;
+};
 
 class AsyncTaskManager {
   public:
     explicit AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool);
+    ~AsyncTaskManager();
 
-    void PostTask(AsyncTask asyncTask);
+    template <typename TaskType, class... Args>
+    Ref<TaskType> PostTask(Args&&... args) {
+        Ref<TaskType> asyncTask = AcquireRef(new TaskType(std::forward<Args>(args)...));
+        PostConstructedTask(asyncTask);
+        return asyncTask;
+    }
+
     void WaitAllPendingTasks();
     bool HasPendingTasks();
 
   private:
-    class WaitableTask : public RefCounted {
-      public:
-        WaitableTask();
-        ~WaitableTask() override;
-
-        AsyncTask asyncTask;
+    struct WaitableTask : NonCopyable {
+        Ref<AsyncTask> asyncTask;
         raw_ptr<AsyncTaskManager> taskManager;
-        std::unique_ptr<dawn::platform::WaitableEvent> waitableEvent;
     };
 
-    static void DoWaitableTask(void* task);
+    void PostConstructedTask(Ref<AsyncTask> asyncTask);
+
+    static void RunTask(void* task);
     void HandleTaskCompletion(WaitableTask* task);
 
-    std::mutex mPendingTasksMutex;
-    absl::flat_hash_map<WaitableTask*, Ref<WaitableTask>> mPendingTasks;
+    using PendingTasksSet = absl::flat_hash_set<std::unique_ptr<WaitableTask>>;
+    MutexProtected<PendingTasksSet> mPendingTasks;
+
     raw_ptr<dawn::platform::WorkerTaskPool> mWorkerTaskPool;
 };
 
diff --git a/src/dawn/native/CreatePipelineAsyncEvent.cpp b/src/dawn/native/CreatePipelineAsyncEvent.cpp
index 5f8521b..6eb5ce0 100644
--- a/src/dawn/native/CreatePipelineAsyncEvent.cpp
+++ b/src/dawn/native/CreatePipelineAsyncEvent.cpp
@@ -176,7 +176,7 @@
                             "CreatePipelineAsyncEvent::InitializeAsync", this, "label", eventLabel);
 
     auto asyncTask = [event = Ref<CreatePipelineAsyncEvent>(this)] { event->InitializeImpl(true); };
-    device->GetAsyncTaskManager()->PostTask(std::move(asyncTask));
+    device->GetAsyncTaskManager()->PostTask<AsyncTask>(std::move(asyncTask));
 }
 
 template <typename PipelineType, typename CreatePipelineAsyncCallbackInfo>
diff --git a/src/dawn/tests/unittests/AsyncTaskTests.cpp b/src/dawn/tests/unittests/AsyncTaskTests.cpp
index 6db4c33..65a1ace 100644
--- a/src/dawn/tests/unittests/AsyncTaskTests.cpp
+++ b/src/dawn/tests/unittests/AsyncTaskTests.cpp
@@ -40,7 +40,7 @@
 #include "dawn/platform/DawnPlatform.h"
 #include "gtest/gtest.h"
 
-namespace dawn {
+namespace dawn::native {
 namespace {
 
 struct SimpleTaskResult {
@@ -82,14 +82,13 @@
     platform::Platform platform;
     std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
 
-    native::AsyncTaskManager taskManager(pool.get());
+    AsyncTaskManager taskManager(pool.get());
     ConcurrentTaskResultQueue taskResultQueue;
 
     constexpr size_t kTaskCount = 4u;
     std::set<uint32_t> idset;
     for (uint32_t i = 0; i < kTaskCount; ++i) {
-        native::AsyncTask asyncTask([&taskResultQueue, i] { DoTask(&taskResultQueue, i); });
-        taskManager.PostTask(std::move(asyncTask));
+        taskManager.PostTask<AsyncTask>([&taskResultQueue, i] { DoTask(&taskResultQueue, i); });
         idset.insert(i);
     }
 
@@ -103,5 +102,65 @@
     ASSERT_TRUE(idset.empty());
 }
 
+// Test that the task status is updated based on the task's running state
+TEST_F(AsyncTaskTest, Status) {
+    platform::Platform platform;
+    std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
+
+    AsyncTaskManager taskManager(pool.get());
+    ConcurrentTaskResultQueue taskResultQueue;
+
+    // Use a mutex to force the task to wait on the main thread before completing
+    std::mutex mutex;
+    std::unique_lock lock(mutex);
+
+    auto task = taskManager.PostTask<AsyncTask>(
+        [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); });
+
+    ASSERT_NE(task->GetState(), AsyncTaskState::Completed);
+
+    // Allow the task to complete
+    lock.unlock();
+    task->Wait();
+    ASSERT_EQ(task->GetState(), AsyncTaskState::Completed);
+}
+
+// Test coverage of the completion callbacks for tasks
+TEST_F(AsyncTaskTest, Callbacks) {
+    platform::Platform platform;
+    std::unique_ptr<platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
+
+    AsyncTaskManager taskManager(pool.get());
+    ConcurrentTaskResultQueue taskResultQueue;
+
+    // Use a mutex to force the task to wait on the main thread before completing
+    std::mutex mutex;
+    std::unique_lock lock(mutex);
+    auto waitingTaskFunction = taskManager.PostTask<AsyncTask>(
+        [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); });
+
+    // Use a completion callback that simply counts how many times it's been called
+    std::atomic<uint64_t> completionCallbackCounter = 0;
+    AsyncTaskCompletionCallback completionCallback = [&completionCallbackCounter]() {
+        completionCallbackCounter++;
+    };
+
+    // Spawn a task that waits for the mutex. Add a completion callback and confirm that it's not
+    // called before the task completes
+    auto task = taskManager.PostTask<AsyncTask>(
+        [&mutex]() { std::scoped_lock<std::mutex> taskLock(mutex); });
+    task->AddCompletionCallback(completionCallback);
+    EXPECT_EQ(completionCallbackCounter, 0u);
+
+    // Allow the task to complete and expect that the completion callback has been called
+    lock.unlock();
+    task->Wait();
+    EXPECT_EQ(completionCallbackCounter, 1u);
+
+    // Add another completion callback to the already-completed task and check that it is called
+    // immediately
+    task->AddCompletionCallback(completionCallback);
+    EXPECT_EQ(completionCallbackCounter, 2u);
+}
 }  // anonymous namespace
-}  // namespace dawn
+}  // namespace dawn::native
diff --git a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
index 225bfed..7211ac3 100644
--- a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
+++ b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp
@@ -49,12 +49,12 @@
     std::atomic_bool done(false);
 
     // Simulate that an async task would take a long time to finish.
-    dawn::native::AsyncTask asyncTask([&done] {
+    AsyncTaskFunction task([&done] {
         std::this_thread::sleep_for(std::chrono::milliseconds(1000));
         done = true;
     });
 
-    mDeviceMock->GetAsyncTaskManager()->PostTask(std::move(asyncTask));
+    mDeviceMock->GetAsyncTaskManager()->PostTask<AsyncTask>(std::move(task));
     DropDevice();
     // Dropping the device should force the async task to finish.
     EXPECT_TRUE(done.load());