Implement CallbackTaskManager for Create*PipelineAsync

This patch implements CallbackTask and CallbackTaskManager to store the
callbacks of Create*PipelineAsync(). Unlike the CreatePipelineAsyncTracker
it replaces, the stored callbacks are no longer tied to a command serial:
all pending callbacks are called during the next Device::Tick(). In the
future, CallbackTaskManager will manage all the callbacks that should be
called in Device::Tick().

BUG=dawn:529
Change-Id: I6ad4352371eb44515bc2d85cdc68220c9b758b8e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/49060
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn index 711fa8c..8e4f0c4 100644 --- a/src/dawn_native/BUILD.gn +++ b/src/dawn_native/BUILD.gn
@@ -184,6 +184,8 @@ "Buffer.h", "CachedObject.cpp", "CachedObject.h", + "CallbackTaskManager.cpp", + "CallbackTaskManager.h", "CommandAllocator.cpp", "CommandAllocator.h", "CommandBuffer.cpp", @@ -204,8 +206,8 @@ "ComputePipeline.h", "CopyTextureForBrowserHelper.cpp", "CopyTextureForBrowserHelper.h", - "CreatePipelineAsyncTracker.cpp", - "CreatePipelineAsyncTracker.h", + "CreatePipelineAsyncTask.cpp", + "CreatePipelineAsyncTask.h", "Device.cpp", "Device.h", "DynamicUploader.cpp",
diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt index ba6cec9..b0d470b 100644 --- a/src/dawn_native/CMakeLists.txt +++ b/src/dawn_native/CMakeLists.txt
@@ -50,6 +50,8 @@ "Buffer.h" "CachedObject.cpp" "CachedObject.h" + "CallbackTaskManager.cpp" + "CallbackTaskManager.h" "CommandAllocator.cpp" "CommandAllocator.h" "CommandBuffer.cpp" @@ -70,8 +72,8 @@ "ComputePipeline.h" "CopyTextureForBrowserHelper.cpp" "CopyTextureForBrowserHelper.h" - "CreatePipelineAsyncTracker.cpp" - "CreatePipelineAsyncTracker.h" + "CreatePipelineAsyncTask.cpp" + "CreatePipelineAsyncTask.h" "Device.cpp" "Device.h" "DynamicUploader.cpp"
diff --git a/src/dawn_native/CallbackTaskManager.cpp b/src/dawn_native/CallbackTaskManager.cpp new file mode 100644 index 0000000..1c9106c --- /dev/null +++ b/src/dawn_native/CallbackTaskManager.cpp
@@ -0,0 +1,39 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/CallbackTaskManager.h"
+
+namespace dawn_native {
+
+    bool CallbackTaskManager::IsEmpty() const {
+        std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
+        return mCallbackTaskQueue.empty();
+    }
+
+    std::vector<std::unique_ptr<CallbackTask>> CallbackTaskManager::AcquireCallbackTasks() {
+        std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
+
+        // Move the whole queue out while holding the lock; the lock is released on return,
+        // before the caller runs the tasks, so a callback can re-enter AddCallbackTask().
+        std::vector<std::unique_ptr<CallbackTask>> allTasks;
+        allTasks.swap(mCallbackTaskQueue);
+        return allTasks;
+    }
+
+    void CallbackTaskManager::AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask) {
+        std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
+        mCallbackTaskQueue.push_back(std::move(callbackTask));
+    }
+
+}  // namespace dawn_native
diff --git a/src/dawn_native/CallbackTaskManager.h b/src/dawn_native/CallbackTaskManager.h
new file mode 100644
index 0000000..1be0eb2
--- /dev/null
+++ b/src/dawn_native/CallbackTaskManager.h
@@ -0,0 +1,56 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_CALLBACK_TASK_MANAGER_H_
+#define DAWNNATIVE_CALLBACK_TASK_MANAGER_H_
+
+#include <memory>
+#include <mutex>
+#include <vector>
+
+namespace dawn_native {
+
+    // A pending callback held by CallbackTaskManager. Whoever drains the manager calls one
+    // of these on each task; Device calls Finish() from Tick(), HandleShutDown() on
+    // shutdown and HandleDeviceLoss() when the device is lost.
+    struct CallbackTask {
+        virtual ~CallbackTask() = default;
+        virtual void Finish() = 0;
+        virtual void HandleShutDown() = 0;
+        virtual void HandleDeviceLoss() = 0;
+    };
+
+    // A list of pending CallbackTasks guarded by a mutex that every method locks.
+    class CallbackTaskManager {
+      public:
+        // Appends |callbackTask| to the pending tasks.
+        void AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask);
+
+        // Returns whether no task is pending. The answer is only a snapshot and can be
+        // outdated as soon as the lock is released.
+        bool IsEmpty() const;
+
+        // Removes and returns all pending tasks in insertion order. The mutex is released on
+        // return, so running the returned tasks may call AddCallbackTask() without deadlocking.
+        std::vector<std::unique_ptr<CallbackTask>> AcquireCallbackTasks();
+
+      private:
+        // mutable so that the const IsEmpty() can lock it.
+        mutable std::mutex mCallbackTaskQueueMutex;
+        std::vector<std::unique_ptr<CallbackTask>> mCallbackTaskQueue;
+    };
+
+}  // namespace dawn_native
+
+#endif  // DAWNNATIVE_CALLBACK_TASK_MANAGER_H_
diff --git a/src/dawn_native/CreatePipelineAsyncTask.cpp b/src/dawn_native/CreatePipelineAsyncTask.cpp new file mode 100644 index 0000000..b6a32b1 --- /dev/null +++ b/src/dawn_native/CreatePipelineAsyncTask.cpp
@@ -0,0 +1,109 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/CreatePipelineAsyncTask.h"
+
+#include "dawn_native/ComputePipeline.h"
+#include "dawn_native/Device.h"
+#include "dawn_native/RenderPipeline.h"
+
+#include <utility>
+
+namespace dawn_native {
+
+    CreatePipelineAsyncCallbackTaskBase::CreatePipelineAsyncCallbackTaskBase(
+        std::string errorMessage,
+        void* userdata)
+        : mErrorMessage(std::move(errorMessage)), mUserData(userdata) {
+    }
+
+    CreateComputePipelineAsyncCallbackTask::CreateComputePipelineAsyncCallbackTask(
+        Ref<ComputePipelineBase> pipeline,
+        std::string errorMessage,
+        WGPUCreateComputePipelineAsyncCallback callback,
+        void* userdata)
+        : CreatePipelineAsyncCallbackTaskBase(std::move(errorMessage), userdata),
+          mPipeline(std::move(pipeline)),
+          mCreateComputePipelineAsyncCallback(callback) {
+    }
+
+    // Reports the creation result: on success the pipeline reference is detached and handed
+    // to the callback, otherwise the stored error message is reported with a null pipeline.
+    void CreateComputePipelineAsyncCallbackTask::Finish() {
+        ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
+
+        if (mPipeline.Get() != nullptr) {
+            mCreateComputePipelineAsyncCallback(
+                WGPUCreatePipelineAsyncStatus_Success,
+                reinterpret_cast<WGPUComputePipeline>(mPipeline.Detach()), "", mUserData);
+        } else {
+            mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_Error, nullptr,
+                                                mErrorMessage.c_str(), mUserData);
+        }
+    }
+
+    void CreateComputePipelineAsyncCallbackTask::HandleShutDown() {
+        ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
+
+        mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
+                                            "Device destroyed before callback", mUserData);
+    }
+
+    void CreateComputePipelineAsyncCallbackTask::HandleDeviceLoss() {
+        ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
+
+        mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
+                                            "Device lost before callback", mUserData);
+    }
+
+    CreateRenderPipelineAsyncCallbackTask::CreateRenderPipelineAsyncCallbackTask(
+        Ref<RenderPipelineBase> pipeline,
+        std::string errorMessage,
+        WGPUCreateRenderPipelineAsyncCallback callback,
+        void* userdata)
+        : CreatePipelineAsyncCallbackTaskBase(std::move(errorMessage), userdata),
+          mPipeline(std::move(pipeline)),
+          mCreateRenderPipelineAsyncCallback(callback) {
+    }
+
+    // Reports the creation result: on success the pipeline reference is detached and handed
+    // to the callback, otherwise the stored error message is reported with a null pipeline.
+    void CreateRenderPipelineAsyncCallbackTask::Finish() {
+        ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
+
+        if (mPipeline.Get() != nullptr) {
+            mCreateRenderPipelineAsyncCallback(
+                WGPUCreatePipelineAsyncStatus_Success,
+                reinterpret_cast<WGPURenderPipeline>(mPipeline.Detach()), "", mUserData);
+        } else {
+            mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_Error, nullptr,
+                                               mErrorMessage.c_str(), mUserData);
+        }
+    }
+
+    void CreateRenderPipelineAsyncCallbackTask::HandleShutDown() {
+        ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
+
+        mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
+                                           "Device destroyed before callback", mUserData);
+    }
+
+    void CreateRenderPipelineAsyncCallbackTask::HandleDeviceLoss() {
+        ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
+
+        mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
+                                           "Device lost before callback", mUserData);
+    }
+
+}  // namespace dawn_native
diff --git a/src/dawn_native/CreatePipelineAsyncTask.h b/src/dawn_native/CreatePipelineAsyncTask.h new file mode 100644 index 0000000..9cddfa2 --- /dev/null +++ b/src/dawn_native/CreatePipelineAsyncTask.h
@@ -0,0 +1,75 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_CREATEPIPELINEASYNCTASK_H_
+#define DAWNNATIVE_CREATEPIPELINEASYNCTASK_H_
+
+#include "common/RefCounted.h"
+#include "dawn/webgpu.h"
+#include "dawn_native/CallbackTaskManager.h"
+
+#include <string>
+
+namespace dawn_native {
+
+    class ComputePipelineBase;
+    class DeviceBase;
+    class RenderPipelineBase;
+
+    // Common state of the Create*PipelineAsync callback tasks: the message reported when
+    // pipeline creation failed and the userdata passed back to the application's callback.
+    struct CreatePipelineAsyncCallbackTaskBase : CallbackTask {
+        CreatePipelineAsyncCallbackTaskBase(std::string errorMessage, void* userData);
+
+      protected:
+        std::string mErrorMessage;
+        void* mUserData;
+    };
+
+    // Reports the outcome of CreateComputePipelineAsync() through |callback|. A null |pipeline|
+    // means creation failed, in which case |errorMessage| is reported instead.
+    struct CreateComputePipelineAsyncCallbackTask final : CreatePipelineAsyncCallbackTaskBase {
+        CreateComputePipelineAsyncCallbackTask(Ref<ComputePipelineBase> pipeline,
+                                               std::string errorMessage,
+                                               WGPUCreateComputePipelineAsyncCallback callback,
+                                               void* userdata);
+
+        void Finish() final;
+        void HandleShutDown() final;
+        void HandleDeviceLoss() final;
+
+      private:
+        Ref<ComputePipelineBase> mPipeline;
+        WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback;
+    };
+
+    // Render pipeline counterpart of CreateComputePipelineAsyncCallbackTask.
+    struct CreateRenderPipelineAsyncCallbackTask final : CreatePipelineAsyncCallbackTaskBase {
+        CreateRenderPipelineAsyncCallbackTask(Ref<RenderPipelineBase> pipeline,
+                                              std::string errorMessage,
+                                              WGPUCreateRenderPipelineAsyncCallback callback,
+                                              void* userdata);
+
+        void Finish() final;
+        void HandleShutDown() final;
+        void HandleDeviceLoss() final;
+
+      private:
+        Ref<RenderPipelineBase> mPipeline;
+        WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback;
+    };
+
+}  // namespace dawn_native
+
+#endif  // DAWNNATIVE_CREATEPIPELINEASYNCTASK_H_
diff --git a/src/dawn_native/CreatePipelineAsyncTracker.cpp b/src/dawn_native/CreatePipelineAsyncTracker.cpp deleted file mode 100644 index 23b8310..0000000 --- a/src/dawn_native/CreatePipelineAsyncTracker.cpp +++ /dev/null
@@ -1,149 +0,0 @@ -// Copyright 2020 The Dawn Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dawn_native/CreatePipelineAsyncTracker.h" - -#include "dawn_native/ComputePipeline.h" -#include "dawn_native/Device.h" -#include "dawn_native/RenderPipeline.h" - -namespace dawn_native { - - CreatePipelineAsyncTaskBase::CreatePipelineAsyncTaskBase(std::string errorMessage, - void* userdata) - : mErrorMessage(errorMessage), mUserData(userdata) { - } - - CreatePipelineAsyncTaskBase::~CreatePipelineAsyncTaskBase() { - } - - CreateComputePipelineAsyncTask::CreateComputePipelineAsyncTask( - Ref<ComputePipelineBase> pipeline, - std::string errorMessage, - WGPUCreateComputePipelineAsyncCallback callback, - void* userdata) - : CreatePipelineAsyncTaskBase(errorMessage, userdata), - mPipeline(std::move(pipeline)), - mCreateComputePipelineAsyncCallback(callback) { - } - - void CreateComputePipelineAsyncTask::Finish() { - ASSERT(mCreateComputePipelineAsyncCallback != nullptr); - - if (mPipeline.Get() != nullptr) { - mCreateComputePipelineAsyncCallback( - WGPUCreatePipelineAsyncStatus_Success, - reinterpret_cast<WGPUComputePipeline>(mPipeline.Detach()), "", mUserData); - } else { - mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_Error, nullptr, - mErrorMessage.c_str(), mUserData); - } - } - - void CreateComputePipelineAsyncTask::HandleShutDown() { - ASSERT(mCreateComputePipelineAsyncCallback != nullptr); - - 
mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, - "Device destroyed before callback", mUserData); - } - - void CreateComputePipelineAsyncTask::HandleDeviceLoss() { - ASSERT(mCreateComputePipelineAsyncCallback != nullptr); - - mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, - "Device lost before callback", mUserData); - } - - CreateRenderPipelineAsyncTask::CreateRenderPipelineAsyncTask( - Ref<RenderPipelineBase> pipeline, - std::string errorMessage, - WGPUCreateRenderPipelineAsyncCallback callback, - void* userdata) - : CreatePipelineAsyncTaskBase(errorMessage, userdata), - mPipeline(std::move(pipeline)), - mCreateRenderPipelineAsyncCallback(callback) { - } - - void CreateRenderPipelineAsyncTask::Finish() { - ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); - - if (mPipeline.Get() != nullptr) { - mCreateRenderPipelineAsyncCallback( - WGPUCreatePipelineAsyncStatus_Success, - reinterpret_cast<WGPURenderPipeline>(mPipeline.Detach()), "", mUserData); - } else { - mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_Error, nullptr, - mErrorMessage.c_str(), mUserData); - } - } - - void CreateRenderPipelineAsyncTask::HandleShutDown() { - ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); - - mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, - "Device destroyed before callback", mUserData); - } - - void CreateRenderPipelineAsyncTask::HandleDeviceLoss() { - ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); - - mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, - "Device lost before callback", mUserData); - } - - CreatePipelineAsyncTracker::CreatePipelineAsyncTracker(DeviceBase* device) : mDevice(device) { - } - - CreatePipelineAsyncTracker::~CreatePipelineAsyncTracker() { - ASSERT(mCreatePipelineAsyncTasksInFlight.Empty()); - } - - void 
CreatePipelineAsyncTracker::TrackTask(std::unique_ptr<CreatePipelineAsyncTaskBase> task, - ExecutionSerial serial) { - mCreatePipelineAsyncTasksInFlight.Enqueue(std::move(task), serial); - mDevice->AddFutureSerial(serial); - } - - void CreatePipelineAsyncTracker::Tick(ExecutionSerial finishedSerial) { - // If a user calls Queue::Submit inside Create*PipelineAsync, then the device will be - // ticked, which in turns ticks the tracker, causing reentrance here. To prevent the - // reentrant call from invalidating mCreatePipelineAsyncTasksInFlight while in use by the - // first call, we remove the tasks to finish from the queue, update - // mCreatePipelineAsyncTasksInFlight, then run the callbacks. - std::vector<std::unique_ptr<CreatePipelineAsyncTaskBase>> tasks; - for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateUpTo(finishedSerial)) { - tasks.push_back(std::move(task)); - } - mCreatePipelineAsyncTasksInFlight.ClearUpTo(finishedSerial); - - for (auto& task : tasks) { - task->Finish(); - } - } - - void CreatePipelineAsyncTracker::ClearForShutDown() { - for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateAll()) { - task->HandleShutDown(); - } - mCreatePipelineAsyncTasksInFlight.Clear(); - } - - void CreatePipelineAsyncTracker::ClearForDeviceLoss() { - for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateAll()) { - task->HandleDeviceLoss(); - } - mCreatePipelineAsyncTasksInFlight.Clear(); - } - -} // namespace dawn_native
diff --git a/src/dawn_native/CreatePipelineAsyncTracker.h b/src/dawn_native/CreatePipelineAsyncTracker.h deleted file mode 100644 index 738d719..0000000 --- a/src/dawn_native/CreatePipelineAsyncTracker.h +++ /dev/null
@@ -1,93 +0,0 @@ -// Copyright 2020 The Dawn Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_ -#define DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_ - -#include "common/RefCounted.h" -#include "common/SerialQueue.h" -#include "dawn/webgpu.h" -#include "dawn_native/IntegerTypes.h" - -#include <memory> -#include <string> - -namespace dawn_native { - - class ComputePipelineBase; - class DeviceBase; - class RenderPipelineBase; - - struct CreatePipelineAsyncTaskBase { - CreatePipelineAsyncTaskBase(std::string errorMessage, void* userData); - virtual ~CreatePipelineAsyncTaskBase(); - - virtual void Finish() = 0; - virtual void HandleShutDown() = 0; - virtual void HandleDeviceLoss() = 0; - - protected: - std::string mErrorMessage; - void* mUserData; - }; - - struct CreateComputePipelineAsyncTask final : public CreatePipelineAsyncTaskBase { - CreateComputePipelineAsyncTask(Ref<ComputePipelineBase> pipeline, - std::string errorMessage, - WGPUCreateComputePipelineAsyncCallback callback, - void* userdata); - - void Finish() final; - void HandleShutDown() final; - void HandleDeviceLoss() final; - - private: - Ref<ComputePipelineBase> mPipeline; - WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback; - }; - - struct CreateRenderPipelineAsyncTask final : public CreatePipelineAsyncTaskBase { - CreateRenderPipelineAsyncTask(Ref<RenderPipelineBase> pipeline, - std::string errorMessage, - 
WGPUCreateRenderPipelineAsyncCallback callback, - void* userdata); - - void Finish() final; - void HandleShutDown() final; - void HandleDeviceLoss() final; - - private: - Ref<RenderPipelineBase> mPipeline; - WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback; - }; - - class CreatePipelineAsyncTracker { - public: - explicit CreatePipelineAsyncTracker(DeviceBase* device); - ~CreatePipelineAsyncTracker(); - - void TrackTask(std::unique_ptr<CreatePipelineAsyncTaskBase> task, ExecutionSerial serial); - void Tick(ExecutionSerial finishedSerial); - void ClearForShutDown(); - void ClearForDeviceLoss(); - - private: - DeviceBase* mDevice; - SerialQueue<ExecutionSerial, std::unique_ptr<CreatePipelineAsyncTaskBase>> - mCreatePipelineAsyncTasksInFlight; - }; - -} // namespace dawn_native - -#endif // DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_
diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 713b88e..a2443b2 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp
@@ -20,11 +20,12 @@ #include "dawn_native/BindGroup.h" #include "dawn_native/BindGroupLayout.h" #include "dawn_native/Buffer.h" +#include "dawn_native/CallbackTaskManager.h" #include "dawn_native/CommandBuffer.h" #include "dawn_native/CommandEncoder.h" #include "dawn_native/CompilationMessages.h" #include "dawn_native/ComputePipeline.h" -#include "dawn_native/CreatePipelineAsyncTracker.h" +#include "dawn_native/CreatePipelineAsyncTask.h" #include "dawn_native/DynamicUploader.h" #include "dawn_native/ErrorData.h" #include "dawn_native/ErrorScope.h" @@ -125,7 +126,7 @@ mCaches = std::make_unique<DeviceBase::Caches>(); mErrorScopeStack = std::make_unique<ErrorScopeStack>(); mDynamicUploader = std::make_unique<DynamicUploader>(this); - mCreatePipelineAsyncTracker = std::make_unique<CreatePipelineAsyncTracker>(this); + mCallbackTaskManager = std::make_unique<CallbackTaskManager>(); mDeprecationWarnings = std::make_unique<DeprecationWarnings>(); mInternalPipelineStore = std::make_unique<InternalPipelineStore>(); mPersistentCache = std::make_unique<PersistentCache>(this); @@ -142,8 +143,11 @@ void DeviceBase::ShutDownBase() { // Skip handling device facilities if they haven't even been created (or failed doing so) if (mState != State::BeingCreated) { - // Reject all async pipeline creations. - mCreatePipelineAsyncTracker->ClearForShutDown(); + // Call all the callbacks immediately as the device is about to shut down. + auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); + for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) { + callbackTask->HandleShutDown(); + } } // Disconnect the device, depending on which state we are currently in. 
@@ -188,7 +192,7 @@ mState = State::Disconnected; mDynamicUploader = nullptr; - mCreatePipelineAsyncTracker = nullptr; + mCallbackTaskManager = nullptr; mPersistentCache = nullptr; mEmptyBindGroupLayout = nullptr; @@ -238,7 +242,10 @@ } mQueue->HandleDeviceLoss(); - mCreatePipelineAsyncTracker->ClearForDeviceLoss(); + auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); + for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) { + callbackTask->HandleDeviceLoss(); + } // Still forward device loss errors to the error scopes so they all reject. mErrorScopeStack->HandleError(ToWGPUErrorType(type), message); @@ -766,10 +773,10 @@ } Ref<RenderPipelineBase> result = maybeResult.AcquireSuccess(); - std::unique_ptr<CreateRenderPipelineAsyncTask> request = - std::make_unique<CreateRenderPipelineAsyncTask>(std::move(result), "", callback, - userdata); - mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial()); + std::unique_ptr<CreateRenderPipelineAsyncCallbackTask> callbackTask = + std::make_unique<CreateRenderPipelineAsyncCallbackTask>(std::move(result), "", callback, + userdata); + mCallbackTaskManager->AddCallbackTask(std::move(callbackTask)); } RenderBundleEncoder* DeviceBase::APICreateRenderBundleEncoder( const RenderBundleEncoderDescriptor* descriptor) { @@ -951,8 +958,19 @@ // reclaiming resources one tick earlier. mDynamicUploader->Deallocate(mCompletedSerial); mQueue->Tick(mCompletedSerial); + } - mCreatePipelineAsyncTracker->Tick(mCompletedSerial); + // We have to check mCallbackTaskManager in every Tick because it is not related to any + // global serials. + if (!mCallbackTaskManager->IsEmpty()) { + // If a user calls Queue::Submit inside the callback, then the device will be ticked, + // which in turns ticks the tracker, causing reentrance and dead lock here. 
To prevent + // such reentrant call, we remove all the callback tasks from mCallbackTaskManager, + // update mCallbackTaskManager, then call all the callbacks. + auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); + for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) { + callbackTask->Finish(); + } } return {}; @@ -1158,10 +1176,10 @@ result = AddOrGetCachedPipeline(resultOrError.AcquireSuccess(), blueprintHash); } - std::unique_ptr<CreateComputePipelineAsyncTask> request = - std::make_unique<CreateComputePipelineAsyncTask>(result, errorMessage, callback, - userdata); - mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial()); + std::unique_ptr<CreateComputePipelineAsyncCallbackTask> callbackTask = + std::make_unique<CreateComputePipelineAsyncCallbackTask>( + std::move(result), errorMessage, callback, userdata); + mCallbackTaskManager->AddCallbackTask(std::move(callbackTask)); } ResultOrError<Ref<PipelineLayoutBase>> DeviceBase::CreatePipelineLayout(
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h index 87d567f..134bff8 100644 --- a/src/dawn_native/Device.h +++ b/src/dawn_native/Device.h
@@ -34,7 +34,7 @@ class AttachmentState; class AttachmentStateBlueprint; class BindGroupLayoutBase; - class CreatePipelineAsyncTracker; + class CallbackTaskManager; class DynamicUploader; class ErrorScopeStack; class ExternalTextureBase; @@ -402,7 +402,7 @@ Ref<BindGroupLayoutBase> mEmptyBindGroupLayout; std::unique_ptr<DynamicUploader> mDynamicUploader; - std::unique_ptr<CreatePipelineAsyncTracker> mCreatePipelineAsyncTracker; + std::unique_ptr<CallbackTaskManager> mCallbackTaskManager; Ref<QueueBase> mQueue; struct DeprecationWarnings;