Fix use-after-free issue in Create*PipelineAsyncTasks::Run()
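This patch fixes a use-after-free issue in Create*PipelineAsyncTask::Run().
When pipeline->Initialize() returns an error, the pipeline object is
deleted, but Run() still attempted to call one of its member functions
afterwards. The device pointer is now fetched from the pipeline before
Initialize() is called and reused after it returns.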
BUG=dawn:1310
TEST=dawn_unittests
Change-Id: I57d5ca98d6c97b14df1d7c3bf2941c9cc87adeff
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/81800
Reviewed-by: Loko Kung <lokokung@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/src/dawn/native/CreatePipelineAsyncTask.cpp b/src/dawn/native/CreatePipelineAsyncTask.cpp
index ed7764e..92a3bdf 100644
--- a/src/dawn/native/CreatePipelineAsyncTask.cpp
+++ b/src/dawn/native/CreatePipelineAsyncTask.cpp
@@ -114,11 +114,13 @@
void CreateComputePipelineAsyncTask::Run() {
const char* eventLabel = utils::GetLabelForTrace(mComputePipeline->GetLabel().c_str());
- TRACE_EVENT_FLOW_END1(mComputePipeline->GetDevice()->GetPlatform(), General,
+
+ DeviceBase* device = mComputePipeline->GetDevice();
+ TRACE_EVENT_FLOW_END1(device->GetPlatform(), General,
"CreateComputePipelineAsyncTask::RunAsync", this, "label",
eventLabel);
- TRACE_EVENT1(mComputePipeline->GetDevice()->GetPlatform(), General,
- "CreateComputePipelineAsyncTask::Run", "label", eventLabel);
+ TRACE_EVENT1(device->GetPlatform(), General, "CreateComputePipelineAsyncTask::Run", "label",
+ eventLabel);
MaybeError maybeError = mComputePipeline->Initialize();
std::string errorMessage;
@@ -127,8 +129,8 @@
errorMessage = maybeError.AcquireError()->GetMessage();
}
- mComputePipeline->GetDevice()->AddComputePipelineAsyncCallbackTask(
- mComputePipeline, errorMessage, mCallback, mUserdata);
+ device->AddComputePipelineAsyncCallbackTask(mComputePipeline, errorMessage, mCallback,
+ mUserdata);
}
void CreateComputePipelineAsyncTask::RunAsync(
@@ -164,10 +166,12 @@
void CreateRenderPipelineAsyncTask::Run() {
const char* eventLabel = utils::GetLabelForTrace(mRenderPipeline->GetLabel().c_str());
- TRACE_EVENT_FLOW_END1(mRenderPipeline->GetDevice()->GetPlatform(), General,
+
+ DeviceBase* device = mRenderPipeline->GetDevice();
+ TRACE_EVENT_FLOW_END1(device->GetPlatform(), General,
"CreateRenderPipelineAsyncTask::RunAsync", this, "label", eventLabel);
- TRACE_EVENT1(mRenderPipeline->GetDevice()->GetPlatform(), General,
- "CreateRenderPipelineAsyncTask::Run", "label", eventLabel);
+ TRACE_EVENT1(device->GetPlatform(), General, "CreateRenderPipelineAsyncTask::Run", "label",
+ eventLabel);
MaybeError maybeError = mRenderPipeline->Initialize();
std::string errorMessage;
@@ -176,8 +180,8 @@
errorMessage = maybeError.AcquireError()->GetMessage();
}
- mRenderPipeline->GetDevice()->AddRenderPipelineAsyncCallbackTask(
- mRenderPipeline, errorMessage, mCallback, mUserdata);
+ device->AddRenderPipelineAsyncCallbackTask(mRenderPipeline, errorMessage, mCallback,
+ mUserdata);
}
void CreateRenderPipelineAsyncTask::RunAsync(
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index cad12e2..01087b8 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -239,6 +239,7 @@
"unittests/ToBackendTests.cpp",
"unittests/TypedIntegerTests.cpp",
"unittests/native/CommandBufferEncodingTests.cpp",
+ "unittests/native/CreatePipelineAsyncTaskTests.cpp",
"unittests/native/DestroyObjectTests.cpp",
"unittests/native/DeviceCreationTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp",
diff --git a/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp b/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp
new file mode 100644
index 0000000..3583b79
--- /dev/null
+++ b/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp
@@ -0,0 +1,73 @@
+// Copyright 2022 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/tests/DawnNativeTest.h"
+
+#include "dawn/native/CreatePipelineAsyncTask.h"
+#include "mocks/ComputePipelineMock.h"
+#include "mocks/RenderPipelineMock.h"
+
+class CreatePipelineAsyncTaskTests : public DawnNativeTest {};
+
+// A regression test for a null pointer issue in CreateRenderPipelineAsyncTask::Run().
+// See crbug.com/dawn/1310 for more details.
+TEST_F(CreatePipelineAsyncTaskTests, InitializationErrorInCreateRenderPipelineAsync) {
+ dawn::native::DeviceBase* deviceBase =
+ reinterpret_cast<dawn::native::DeviceBase*>(device.Get());
+ Ref<dawn::native::RenderPipelineMock> renderPipelineMock =
+ AcquireRef(new dawn::native::RenderPipelineMock(deviceBase));
+
+ ON_CALL(*renderPipelineMock.Get(), Initialize)
+ .WillByDefault(testing::Return(testing::ByMove(
+ DAWN_MAKE_ERROR(dawn::native::InternalErrorType::Validation, "Initialization Error"))));
+
+ dawn::native::CreateRenderPipelineAsyncTask asyncTask(
+ renderPipelineMock,
+ [](WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline returnPipeline,
+ const char* message, void* userdata) {
+ EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_Error, status);
+ },
+ nullptr);
+
+ asyncTask.Run();
+ device.Tick();
+
+ EXPECT_CALL(*renderPipelineMock.Get(), DestroyImpl).Times(1);
+}
+
+// A regression test for a null pointer issue in CreateComputePipelineAsyncTask::Run().
+// See crbug.com/dawn/1310 for more details.
+TEST_F(CreatePipelineAsyncTaskTests, InitializationErrorInCreateComputePipelineAsync) {
+ dawn::native::DeviceBase* deviceBase =
+ reinterpret_cast<dawn::native::DeviceBase*>(device.Get());
+ Ref<dawn::native::ComputePipelineMock> computePipelineMock =
+ AcquireRef(new dawn::native::ComputePipelineMock(deviceBase));
+
+ ON_CALL(*computePipelineMock.Get(), Initialize)
+ .WillByDefault(testing::Return(testing::ByMove(
+ DAWN_MAKE_ERROR(dawn::native::InternalErrorType::Validation, "Initialization Error"))));
+
+ dawn::native::CreateComputePipelineAsyncTask asyncTask(
+ computePipelineMock,
+ [](WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline returnPipeline,
+ const char* message, void* userdata) {
+ EXPECT_EQ(WGPUCreatePipelineAsyncStatus::WGPUCreatePipelineAsyncStatus_Error, status);
+ },
+ nullptr);
+
+ asyncTask.Run();
+ device.Tick();
+
+ EXPECT_CALL(*computePipelineMock.Get(), DestroyImpl).Times(1);
+}