Implement WaitableEvent and WorkerTaskPool for multi-threaded tasks
This patch adds the basic implementation of WaitableEvent and
WorkerTaskPool for multi-threaded tasks in Dawn (for example, the
multi-threaded implementation of CreateReady*Pipeline()).
BUG=dawn:529
TEST=dawn_unittests
Change-Id: Ibf84348f4c0f0d26badc19ae94cd536cef89d084
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/36360
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/common/BUILD.gn b/src/common/BUILD.gn
index 8db531f..761ecd0 100644
--- a/src/common/BUILD.gn
+++ b/src/common/BUILD.gn
@@ -171,6 +171,7 @@
"Math.cpp",
"Math.h",
"NSRef.h",
+ "NonCopyable.h",
"PlacementAllocated.h",
"Platform.h",
"RefBase.h",
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt
index 3b28bba..46c1737 100644
--- a/src/common/CMakeLists.txt
+++ b/src/common/CMakeLists.txt
@@ -33,6 +33,7 @@
"Math.cpp"
"Math.h"
"NSRef.h"
+ "NonCopyable.h"
"PlacementAllocated.h"
"Platform.h"
"RefBase.h"
diff --git a/src/common/NonCopyable.h b/src/common/NonCopyable.h
new file mode 100644
index 0000000..e711f71
--- /dev/null
+++ b/src/common/NonCopyable.h
@@ -0,0 +1,32 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef COMMON_NONCOPYABLE_H_
+#define COMMON_NONCOPYABLE_H_
+
+// NonCopyable:
+// the base class for the classes that are not copyable.
+//
+
+class NonCopyable {
+ protected:
+ constexpr NonCopyable() = default;
+ ~NonCopyable() = default;
+
+ private:
+ NonCopyable(const NonCopyable&) = delete;
+ void operator=(const NonCopyable&) = delete;
+};
+
+#endif
diff --git a/src/dawn_platform/BUILD.gn b/src/dawn_platform/BUILD.gn
index 91c9e75..cbd3238 100644
--- a/src/dawn_platform/BUILD.gn
+++ b/src/dawn_platform/BUILD.gn
@@ -25,6 +25,8 @@
"${dawn_root}/src/include/dawn_platform/DawnPlatform.h",
"${dawn_root}/src/include/dawn_platform/dawn_platform_export.h",
"DawnPlatform.cpp",
+ "WorkerThread.cpp",
+ "WorkerThread.h",
"tracing/EventTracer.cpp",
"tracing/EventTracer.h",
"tracing/TraceEvent.h",
diff --git a/src/dawn_platform/CMakeLists.txt b/src/dawn_platform/CMakeLists.txt
index b8075e2..92372bb 100644
--- a/src/dawn_platform/CMakeLists.txt
+++ b/src/dawn_platform/CMakeLists.txt
@@ -23,6 +23,8 @@
"${DAWN_INCLUDE_DIR}/dawn_platform/DawnPlatform.h"
"${DAWN_INCLUDE_DIR}/dawn_platform/dawn_platform_export.h"
"DawnPlatform.cpp"
+ "WorkerThread.cpp"
+ "WorkerThread.h"
"tracing/EventTracer.cpp"
"tracing/EventTracer.h"
"tracing/TraceEvent.h"
diff --git a/src/dawn_platform/DawnPlatform.cpp b/src/dawn_platform/DawnPlatform.cpp
index b772bac..1bedbcb 100644
--- a/src/dawn_platform/DawnPlatform.cpp
+++ b/src/dawn_platform/DawnPlatform.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "dawn_platform/DawnPlatform.h"
+#include "dawn_platform/WorkerThread.h"
#include "common/Assert.h"
@@ -55,4 +56,8 @@
return nullptr;
}
+ std::unique_ptr<dawn_platform::WorkerTaskPool> Platform::CreateWorkerTaskPool() {
+ return std::make_unique<AsyncWorkerThreadPool>();
+ }
+
} // namespace dawn_platform
diff --git a/src/dawn_platform/WorkerThread.cpp b/src/dawn_platform/WorkerThread.cpp
new file mode 100644
index 0000000..64d09f1
--- /dev/null
+++ b/src/dawn_platform/WorkerThread.cpp
@@ -0,0 +1,51 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_platform/WorkerThread.h"
+
+#include <future>
+
+#include "common/Assert.h"
+
+namespace {
+
+ class AsyncWaitableEvent final : public dawn_platform::WaitableEvent {
+ public:
+ explicit AsyncWaitableEvent(std::function<void()> func) {
+ mFuture = std::async(std::launch::async, func);
+ }
+ virtual ~AsyncWaitableEvent() override {
+ ASSERT(IsComplete());
+ }
+ void Wait() override {
+ ASSERT(mFuture.valid());
+ mFuture.wait();
+ }
+ bool IsComplete() override {
+ ASSERT(mFuture.valid());
+ return mFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
+ }
+
+ private:
+ std::future<void> mFuture;
+ };
+
+} // anonymous namespace
+
+std::unique_ptr<dawn_platform::WaitableEvent> AsyncWorkerThreadPool::PostWorkerTask(
+ dawn_platform::PostWorkerTaskCallback callback,
+ void* userdata) {
+ std::function<void()> doTask = [callback, userdata]() { callback(userdata); };
+ return std::make_unique<AsyncWaitableEvent>(doTask);
+}
\ No newline at end of file
diff --git a/src/dawn_platform/WorkerThread.h b/src/dawn_platform/WorkerThread.h
new file mode 100644
index 0000000..56a5d10
--- /dev/null
+++ b/src/dawn_platform/WorkerThread.h
@@ -0,0 +1,28 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef COMMON_WORKERTHREAD_H_
+#define COMMON_WORKERTHREAD_H_
+
+#include "common/NonCopyable.h"
+#include "dawn_platform/DawnPlatform.h"
+
+class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable {
+ public:
+ std::unique_ptr<dawn_platform::WaitableEvent> PostWorkerTask(
+ dawn_platform::PostWorkerTaskCallback callback,
+ void* userdata) override;
+};
+
+#endif
diff --git a/src/include/dawn_platform/DawnPlatform.h b/src/include/dawn_platform/DawnPlatform.h
index 4a00f53..3a28419 100644
--- a/src/include/dawn_platform/DawnPlatform.h
+++ b/src/include/dawn_platform/DawnPlatform.h
@@ -19,6 +19,7 @@
#include <cstddef>
#include <cstdint>
+#include <memory>
#include <dawn/webgpu.h>
@@ -60,6 +61,24 @@
CachingInterface& operator=(const CachingInterface&) = delete;
};
+ class DAWN_PLATFORM_EXPORT WaitableEvent {
+ public:
+ WaitableEvent() = default;
+ virtual ~WaitableEvent() = default;
+ virtual void Wait() = 0; // Wait for completion
+ virtual bool IsComplete() = 0; // Non-blocking check if the event is complete
+ };
+
+ using PostWorkerTaskCallback = void (*)(void* userdata);
+
+ class DAWN_PLATFORM_EXPORT WorkerTaskPool {
+ public:
+ WorkerTaskPool() = default;
+ virtual ~WorkerTaskPool() = default;
+ virtual std::unique_ptr<WaitableEvent> PostWorkerTask(PostWorkerTaskCallback,
+ void* userdata) = 0;
+ };
+
class DAWN_PLATFORM_EXPORT Platform {
public:
Platform();
@@ -85,6 +104,7 @@
// device which uses it to persistently cache objects.
virtual CachingInterface* GetCachingInterface(const void* fingerprint,
size_t fingerprintSize);
+ virtual std::unique_ptr<WorkerTaskPool> CreateWorkerTaskPool();
private:
Platform(const Platform&) = delete;
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 3971b09..daabd3b 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -181,6 +181,7 @@
"unittests/SystemUtilsTests.cpp",
"unittests/ToBackendTests.cpp",
"unittests/TypedIntegerTests.cpp",
+ "unittests/WorkerThreadTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp",
"unittests/validation/BufferValidationTests.cpp",
"unittests/validation/CommandBufferValidationTests.cpp",
diff --git a/src/tests/unittests/WorkerThreadTests.cpp b/src/tests/unittests/WorkerThreadTests.cpp
new file mode 100644
index 0000000..8faee5d
--- /dev/null
+++ b/src/tests/unittests/WorkerThreadTests.cpp
@@ -0,0 +1,169 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// WorkerThreadTests:
+// Simple tests for the worker thread class.
+
+#include <gtest/gtest.h>
+
+#include <list>
+#include <memory>
+#include <mutex>
+#include <queue>
+
+#include "common/NonCopyable.h"
+#include "dawn_platform/DawnPlatform.h"
+
+namespace {
+
+ struct SimpleTaskResult {
+ uint32_t id;
+ bool isDone = false;
+ };
+
+ // A thread-safe queue that stores the task results.
+ class ConcurrentTaskResultQueue : public NonCopyable {
+ public:
+ void TaskCompleted(const SimpleTaskResult& result) {
+ ASSERT_TRUE(result.isDone);
+
+ std::lock_guard<std::mutex> lock(mMutex);
+ mTaskResultQueue.push(result);
+ }
+
+ std::vector<SimpleTaskResult> GetAndPopCompletedTasks() {
+ std::lock_guard<std::mutex> lock(mMutex);
+
+ std::vector<SimpleTaskResult> results;
+ while (!mTaskResultQueue.empty()) {
+ results.push_back(mTaskResultQueue.front());
+ mTaskResultQueue.pop();
+ }
+ return results;
+ }
+
+ private:
+ std::mutex mMutex;
+ std::queue<SimpleTaskResult> mTaskResultQueue;
+ };
+
+ // A simple task that will be executed asynchronously with pool->PostWorkerTask().
+ class SimpleTask : public NonCopyable {
+ public:
+ SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue)
+ : mId(id), mResultQueue(resultQueue) {
+ }
+
+ private:
+ friend class Tracker;
+
+ static void DoTaskOnWorkerTaskPool(void* task) {
+ SimpleTask* simpleTaskPtr = static_cast<SimpleTask*>(task);
+ simpleTaskPtr->doTask();
+ }
+
+ void doTask() {
+ SimpleTaskResult result;
+ result.id = mId;
+ result.isDone = true;
+ mResultQueue->TaskCompleted(result);
+ }
+
+ uint32_t mId;
+ ConcurrentTaskResultQueue* mResultQueue;
+ };
+
+ // A simple implementation of task tracker which is only called in main thread and not
+ // thread-safe.
+ class Tracker : public NonCopyable {
+ public:
+ explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) {
+ }
+
+ void StartNewTask(uint32_t taskId) {
+ mTasksInFlight.emplace_back(this, mPool, taskId);
+ }
+
+ uint64_t GetTasksInFlightCount() {
+ return mTasksInFlight.size();
+ }
+
+ void WaitAll() {
+ for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) {
+ iter->waitableEvent->Wait();
+ }
+ }
+
+ // In Tick() we clean up all the completed tasks and consume all the available results.
+ void Tick() {
+ auto iter = mTasksInFlight.begin();
+ while (iter != mTasksInFlight.end()) {
+ if (iter->waitableEvent->IsComplete()) {
+ iter = mTasksInFlight.erase(iter);
+ } else {
+ ++iter;
+ }
+ }
+
+ const std::vector<SimpleTaskResult>& results =
+ mCompletedTaskResultQueue.GetAndPopCompletedTasks();
+ for (const SimpleTaskResult& result : results) {
+ EXPECT_TRUE(result.isDone);
+ }
+ }
+
+ private:
+ SimpleTask* CreateSimpleTask(uint32_t taskId) {
+ return new SimpleTask(taskId, &mCompletedTaskResultQueue);
+ }
+
+ struct WaitableTask {
+ WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) {
+ task.reset(tracker->CreateSimpleTask(taskId));
+ waitableEvent =
+ pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get());
+ }
+
+ std::unique_ptr<SimpleTask> task;
+ std::unique_ptr<dawn_platform::WaitableEvent> waitableEvent;
+ };
+
+ dawn_platform::WorkerTaskPool* mPool;
+
+ std::list<WaitableTask> mTasksInFlight;
+ ConcurrentTaskResultQueue mCompletedTaskResultQueue;
+ };
+
+} // anonymous namespace
+
+class WorkerThreadTest : public testing::Test {};
+
+// Emulate the basic usage of worker thread pool in CreateReady*Pipeline().
+TEST_F(WorkerThreadTest, Basic) {
+ dawn_platform::Platform platform;
+ std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
+ Tracker tracker(pool.get());
+
+ constexpr uint32_t kTaskCount = 4;
+ for (uint32_t i = 0; i < kTaskCount; ++i) {
+ tracker.StartNewTask(i);
+ }
+ EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount());
+
+ // Wait for the completion of all the tasks.
+ tracker.WaitAll();
+
+ tracker.Tick();
+ EXPECT_EQ(0u, tracker.GetTasksInFlightCount());
+}