// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// WorkerThreadTests:
// Simple tests for the worker thread class.
#include <gtest/gtest.h>

#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>

#include "common/NonCopyable.h"
#include "dawn_platform/DawnPlatform.h"
namespace {
// The outcome reported by a task once it has run.
struct SimpleTaskResult {
    // Identifier of the task that produced this result. Initialized so that a
    // default-constructed result never carries an indeterminate id.
    uint32_t id = 0;
    // Whether the task has finished; false until the task marks it as done.
    bool isDone = false;
};
// A thread-safe queue that stores the task results.
class ConcurrentTaskResultQueue : public NonCopyable {
public:
void TaskCompleted(const SimpleTaskResult& result) {
ASSERT_TRUE(result.isDone);
std::lock_guard<std::mutex> lock(mMutex);
mTaskResultQueue.push(result);
}
std::vector<SimpleTaskResult> GetAndPopCompletedTasks() {
std::lock_guard<std::mutex> lock(mMutex);
std::vector<SimpleTaskResult> results;
while (!mTaskResultQueue.empty()) {
results.push_back(mTaskResultQueue.front());
mTaskResultQueue.pop();
}
return results;
}
private:
std::mutex mMutex;
std::queue<SimpleTaskResult> mTaskResultQueue;
};
// A simple task that will be executed asynchronously with pool->PostWorkerTask().
class SimpleTask : public NonCopyable {
public:
SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue)
: mId(id), mResultQueue(resultQueue) {
}
private:
friend class Tracker;
static void DoTaskOnWorkerTaskPool(void* task) {
SimpleTask* simpleTaskPtr = static_cast<SimpleTask*>(task);
simpleTaskPtr->doTask();
}
void doTask() {
SimpleTaskResult result;
result.id = mId;
result.isDone = true;
mResultQueue->TaskCompleted(result);
}
uint32_t mId;
ConcurrentTaskResultQueue* mResultQueue;
};
// A simple implementation of task tracker which is only called in main thread and not
// thread-safe.
class Tracker : public NonCopyable {
public:
explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) {
}
void StartNewTask(uint32_t taskId) {
mTasksInFlight.emplace_back(this, mPool, taskId);
}
uint64_t GetTasksInFlightCount() {
return mTasksInFlight.size();
}
void WaitAll() {
for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) {
iter->waitableEvent->Wait();
}
}
// In Tick() we clean up all the completed tasks and consume all the available results.
void Tick() {
auto iter = mTasksInFlight.begin();
while (iter != mTasksInFlight.end()) {
if (iter->waitableEvent->IsComplete()) {
iter = mTasksInFlight.erase(iter);
} else {
++iter;
}
}
const std::vector<SimpleTaskResult>& results =
mCompletedTaskResultQueue.GetAndPopCompletedTasks();
for (const SimpleTaskResult& result : results) {
EXPECT_TRUE(result.isDone);
}
}
private:
SimpleTask* CreateSimpleTask(uint32_t taskId) {
return new SimpleTask(taskId, &mCompletedTaskResultQueue);
}
struct WaitableTask {
WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) {
task.reset(tracker->CreateSimpleTask(taskId));
waitableEvent =
pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get());
}
std::unique_ptr<SimpleTask> task;
std::unique_ptr<dawn_platform::WaitableEvent> waitableEvent;
};
dawn_platform::WorkerTaskPool* mPool;
std::list<WaitableTask> mTasksInFlight;
ConcurrentTaskResultQueue mCompletedTaskResultQueue;
};
} // anonymous namespace
// Fixture of the worker thread tests; it adds no state on top of testing::Test.
class WorkerThreadTest : public testing::Test {};

// Emulate the basic usage of worker thread pool in Create*PipelineAsync(): start a few tasks,
// wait for all of them, then let the tracker clean them up.
TEST_F(WorkerThreadTest, Basic) {
    dawn_platform::Platform platform;
    std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
    Tracker tracker(pool.get());

    // Every started task must be reported as in flight until Tick() is called.
    constexpr uint32_t kTaskCount = 4;
    for (uint32_t taskId = 0; taskId != kTaskCount; ++taskId) {
        tracker.StartNewTask(taskId);
    }
    EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount());

    // Once all the tasks have completed, a single Tick() must stop tracking all of them.
    tracker.WaitAll();
    tracker.Tick();
    EXPECT_EQ(0u, tracker.GetTasksInFlightCount());
}