[dawn][test] Update all device callbacks in tests.
- Modernizes all the unit tests to use mock callbacks for device
lost and uncaptured error.
- Removes any remaining references to device.SetDeviceLostCallback
and device.SetUncapturedErrorCallback in unit tests.
Bug: 42241415
Change-Id: Ide4a803fe21bdbe82683b8704b801635027965fb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/216034
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Shrek Shao <shrekshao@google.com>
diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn
index 2e26a8f..1eed546 100644
--- a/src/dawn/tests/BUILD.gn
+++ b/src/dawn/tests/BUILD.gn
@@ -236,6 +236,7 @@
deps = [
":gmock_and_gtest",
+ ":test_infra_sources",
"${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/native:sources",
"${dawn_root}/src/dawn/native:static",
@@ -286,6 +287,7 @@
":native_mocks_sources",
":partition_alloc_support",
":platform_mocks_sources",
+ ":test_infra_sources",
"${dawn_root}/include/dawn:cpp_headers",
"${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common",
@@ -314,11 +316,6 @@
"${dawn_root}/src/dawn/wire/server/ServerMemoryTransferService_mock.h",
"DawnNativeTest.cpp",
"DawnNativeTest.h",
- "MockCallback.h",
- "ParamGenerator.h",
- "StringViewMatchers.h",
- "ToggleParser.cpp",
- "ToggleParser.h",
"unittests/AsyncTaskTests.cpp",
"unittests/BitSetIteratorTests.cpp",
"unittests/BuddyAllocatorTests.cpp",
@@ -515,6 +512,14 @@
"${dawn_root}/src/dawn/wire",
]
+ public = [
+ "AdapterTestConfig.h",
+ "DawnTest.h",
+ "MockCallback.h",
+ "ParamGenerator.h",
+ "StringViewMatchers.h",
+ "ToggleParser.h",
+ ]
public_deps = [
":gmock_and_gtest",
"${dawn_root}/src/dawn/partition_alloc:raw_ptr",
@@ -526,14 +531,8 @@
sources = [
"AdapterTestConfig.cpp",
- "AdapterTestConfig.h",
"DawnTest.cpp",
- "DawnTest.h",
- "MockCallback.h",
- "ParamGenerator.h",
- "StringViewMatchers.h",
"ToggleParser.cpp",
- "ToggleParser.h",
]
}
diff --git a/src/dawn/tests/DawnTest.h b/src/dawn/tests/DawnTest.h
index 9c78dfd..2cf9fec 100644
--- a/src/dawn/tests/DawnTest.h
+++ b/src/dawn/tests/DawnTest.h
@@ -343,11 +343,9 @@
// Mock callbacks tracking errors and destruction. These are strict mocks because any errors or
// device loss that aren't expected should result in test failures and not just some warnings
// printed to stdout.
- testing::StrictMock<
- testing::MockCppCallback<void (*)(const wgpu::Device&, wgpu::ErrorType, wgpu::StringView)>>
+ testing::StrictMock<testing::MockCppCallback<wgpu::UncapturedErrorCallback<void>*>>
mDeviceErrorCallback;
- testing::StrictMock<testing::MockCppCallback<
- void (*)(const wgpu::Device&, wgpu::DeviceLostReason, wgpu::StringView)>>
+ testing::StrictMock<testing::MockCppCallback<wgpu::DeviceLostCallback2<void>*>>
mDeviceLostCallback;
// Helper methods to implement the EXPECT_ macros
diff --git a/src/dawn/tests/unittests/native/AllowedErrorTests.cpp b/src/dawn/tests/unittests/native/AllowedErrorTests.cpp
index aca5cf2..dbbc373 100644
--- a/src/dawn/tests/unittests/native/AllowedErrorTests.cpp
+++ b/src/dawn/tests/unittests/native/AllowedErrorTests.cpp
@@ -58,10 +58,10 @@
using ::testing::StrictMock;
using ::testing::Test;
-using MockComputePipelineAsyncCallback = MockCppCallback<
- void (*)(wgpu::CreatePipelineAsyncStatus, wgpu::ComputePipeline, wgpu::StringView)>;
-using MockRenderPipelineAsyncCallback = MockCppCallback<
- void (*)(wgpu::CreatePipelineAsyncStatus, wgpu::RenderPipeline, wgpu::StringView)>;
+using MockComputePipelineAsyncCallback =
+ MockCppCallback<wgpu::CreateComputePipelineAsyncCallback2<void>*>;
+using MockRenderPipelineAsyncCallback =
+ MockCppCallback<wgpu::CreateRenderPipelineAsyncCallback2<void>*>;
static constexpr char kOomErrorMessage[] = "Out of memory error";
static constexpr char kInternalErrorMessage[] = "Internal error";
@@ -76,22 +76,7 @@
}
)";
-class AllowedErrorTests : public DawnMockTest {
- public:
- ~AllowedErrorTests() override { DropDevice(); }
-
- protected:
- void SetUp() override {
- DawnMockTest::SetUp();
- device.SetDeviceLostCallback(mDeviceLostCb.Callback(), mDeviceLostCb.MakeUserdata(this));
- device.SetUncapturedErrorCallback(mDeviceErrorCb.Callback(),
- mDeviceErrorCb.MakeUserdata(this));
- }
-
- // Device mock callbacks used throughout the tests.
- StrictMock<MockCallback<wgpu::DeviceLostCallback>> mDeviceLostCb;
- StrictMock<MockCallback<wgpu::ErrorCallback>> mDeviceErrorCb;
-};
+class AllowedErrorTests : public DawnMockTest {};
//
// Exercise APIs where OOM errors cause a device lost.
@@ -102,8 +87,8 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
device.GetQueue().Submit(0, nullptr);
@@ -120,8 +105,8 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
constexpr uint8_t data = 8;
@@ -141,8 +126,8 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
constexpr uint8_t data[] = {1, 2, 4, 8};
@@ -174,8 +159,8 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
device.GetQueue().CopyTextureForBrowser(&src, &dst, &size, &options);
}
@@ -216,8 +201,8 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
device.GetQueue().CopyExternalTextureForBrowser(&src, &dst, &size, &options);
}
@@ -236,8 +221,8 @@
.WillOnce(Return(ByMove(std::move(computePipelineMock))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
device.CreateComputePipeline(ToCppAPI(&desc));
}
@@ -262,8 +247,8 @@
.WillOnce(Return(ByMove(std::move(renderPipelineMock))));
// Expect the device lost because of the error.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Unknown,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
device.CreateRenderPipeline(ToCppAPI(&desc));
}
@@ -283,13 +268,10 @@
.WillOnce(Return(ByMove(std::move(computePipelineMock))));
// Expect the internal error.
- EXPECT_CALL(mDeviceErrorCb, Call(WGPUErrorType_Internal,
- SizedStringMatches(HasSubstr(kInternalErrorMessage)), this))
+ EXPECT_CALL(mDeviceErrorCallback, Call(CHandleIs(device.Get()), wgpu::ErrorType::Internal,
+ SizedStringMatches(HasSubstr(kInternalErrorMessage))))
.Times(1);
device.CreateComputePipeline(ToCppAPI(&desc));
-
- // Device lost should only happen due to destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
// Internal error from synchronously initializing a render pipeline should not result in a device
@@ -313,13 +295,10 @@
.WillOnce(Return(ByMove(std::move(renderPipelineMock))));
// Expect the internal error.
- EXPECT_CALL(mDeviceErrorCb, Call(WGPUErrorType_Internal,
- SizedStringMatches(HasSubstr(kInternalErrorMessage)), this))
+ EXPECT_CALL(mDeviceErrorCallback, Call(CHandleIs(device.Get()), wgpu::ErrorType::Internal,
+ SizedStringMatches(HasSubstr(kInternalErrorMessage))))
.Times(1);
device.CreateRenderPipeline(ToCppAPI(&desc));
-
- // Device lost should only happen due to destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
//
@@ -347,9 +326,6 @@
device.CreateComputePipelineAsync(ToCppAPI(&desc), wgpu::CallbackMode::AllowProcessEvents,
cb.Callback());
ProcessEvents();
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
// OOM error from asynchronously initializing a render pipeline should not result in a device loss.
@@ -379,9 +355,6 @@
device.CreateRenderPipelineAsync(ToCppAPI(&desc), wgpu::CallbackMode::AllowProcessEvents,
cb.Callback());
ProcessEvents();
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
// Internal error from asynchronously initializing a compute pipeline should not result in a device
@@ -406,9 +379,6 @@
device.CreateComputePipelineAsync(ToCppAPI(&desc), wgpu::CallbackMode::AllowProcessEvents,
cb.Callback());
ProcessEvents();
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
// Internal error from asynchronously initializing a render pipeline should not result in a device
@@ -439,9 +409,6 @@
device.CreateRenderPipelineAsync(ToCppAPI(&desc), wgpu::CallbackMode::AllowProcessEvents,
cb.Callback());
ProcessEvents();
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
//
@@ -454,17 +421,14 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the OOM error.
- EXPECT_CALL(mDeviceErrorCb, Call(WGPUErrorType_OutOfMemory,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceErrorCallback, Call(CHandleIs(device.Get()), wgpu::ErrorType::OutOfMemory,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
wgpu::BufferDescriptor desc = {};
desc.usage = wgpu::BufferUsage::Uniform;
desc.size = 16;
device.CreateBuffer(&desc);
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
// OOM error from texture creation is allowed and surfaced directly.
@@ -473,8 +437,8 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the OOM error.
- EXPECT_CALL(mDeviceErrorCb, Call(WGPUErrorType_OutOfMemory,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceErrorCallback, Call(CHandleIs(device.Get()), wgpu::ErrorType::OutOfMemory,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
wgpu::TextureDescriptor desc = {};
@@ -482,9 +446,6 @@
desc.size = {4, 4};
desc.format = wgpu::TextureFormat::RGBA8Unorm;
device.CreateTexture(&desc);
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
// OOM error from query set creation is allowed and surfaced directly.
@@ -493,29 +454,23 @@
.WillOnce(Return(ByMove(DAWN_OUT_OF_MEMORY_ERROR(kOomErrorMessage))));
// Expect the OOM error.
- EXPECT_CALL(mDeviceErrorCb, Call(WGPUErrorType_OutOfMemory,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceErrorCallback, Call(CHandleIs(device.Get()), wgpu::ErrorType::OutOfMemory,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
wgpu::QuerySetDescriptor desc = {};
desc.type = wgpu::QueryType::Occlusion;
desc.count = 1;
device.CreateQuerySet(&desc);
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
TEST_F(AllowedErrorTests, InjectError) {
// Expect the OOM error.
- EXPECT_CALL(mDeviceErrorCb, Call(WGPUErrorType_OutOfMemory,
- SizedStringMatches(HasSubstr(kOomErrorMessage)), this))
+ EXPECT_CALL(mDeviceErrorCallback, Call(CHandleIs(device.Get()), wgpu::ErrorType::OutOfMemory,
+ SizedStringMatches(HasSubstr(kOomErrorMessage))))
.Times(1);
device.InjectError(wgpu::ErrorType::OutOfMemory, kOomErrorMessage);
-
- // Device lost should only happen because of destruction.
- EXPECT_CALL(mDeviceLostCb, Call(WGPUDeviceLostReason_Destroyed, _, this)).Times(1);
}
} // anonymous namespace
diff --git a/src/dawn/tests/unittests/native/DestroyObjectTests.cpp b/src/dawn/tests/unittests/native/DestroyObjectTests.cpp
index caed653..c6c1e90 100644
--- a/src/dawn/tests/unittests/native/DestroyObjectTests.cpp
+++ b/src/dawn/tests/unittests/native/DestroyObjectTests.cpp
@@ -1030,7 +1030,12 @@
EXPECT_TRUE(FromAPI(csModule.Get())->IsAlive());
EXPECT_TRUE(FromAPI(texture.Get())->IsAlive());
EXPECT_TRUE(FromAPI(textureView.Get())->IsAlive());
+
+ EXPECT_CALL(mDeviceLostCallback,
+ Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Destroyed, _))
+ .Times(1);
device.Destroy();
+
EXPECT_FALSE(FromAPI(bindGroup.Get())->IsAlive());
EXPECT_FALSE(FromAPI(bindGroupLayout.Get())->IsAlive());
EXPECT_FALSE(FromAPI(buffer.Get())->IsAlive());
diff --git a/src/dawn/tests/unittests/native/mocks/DawnMockTest.cpp b/src/dawn/tests/unittests/native/mocks/DawnMockTest.cpp
index 410f719..8077a73 100644
--- a/src/dawn/tests/unittests/native/mocks/DawnMockTest.cpp
+++ b/src/dawn/tests/unittests/native/mocks/DawnMockTest.cpp
@@ -32,6 +32,9 @@
#include "dawn/dawn_proc.h"
#include "dawn/native/ChainUtils.h"
+using testing::_;
+using testing::AtMost;
+
namespace dawn::native {
DawnMockTest::DawnMockTest() : mDeviceToggles(ToggleStage::Device) {}
@@ -42,15 +45,21 @@
const auto& adapters = instance->EnumerateAdapters();
DAWN_ASSERT(!adapters.empty());
- auto result = ValidateAndUnpack(&mDeviceDescriptor);
- DAWN_ASSERT(result.IsSuccess());
- UnpackedPtr<DeviceDescriptor> packedDeviceDescriptor = result.AcquireSuccess();
+ wgpu::DeviceDescriptor desc = {};
+ desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
+ mDeviceLostCallback.Callback());
+ desc.SetUncapturedErrorCallback(mDeviceErrorCallback.TemplatedCallback(),
+ mDeviceErrorCallback.TemplatedCallbackUserdata());
+ DeviceDescriptor* nativeDesc = reinterpret_cast<DeviceDescriptor*>(&desc);
- Ref<DeviceBase::DeviceLostEvent> lostEvent =
- DeviceBase::DeviceLostEvent::Create(&mDeviceDescriptor);
+ auto result = ValidateAndUnpack(nativeDesc);
+ DAWN_ASSERT(result.IsSuccess());
+ UnpackedPtr<DeviceDescriptor> unpackedDesc = result.AcquireSuccess();
+
+ Ref<DeviceBase::DeviceLostEvent> lostEvent = DeviceBase::DeviceLostEvent::Create(nativeDesc);
auto deviceMock = AcquireRef(new ::testing::NiceMock<DeviceMock>(
- adapters[0].Get(), packedDeviceDescriptor, mDeviceToggles, std::move(lostEvent)));
+ adapters[0].Get(), unpackedDesc, mDeviceToggles, std::move(lostEvent)));
mDeviceMock = deviceMock.Get();
device = wgpu::Device::Acquire(ToAPI(ReturnToAPI<DeviceBase>(std::move(deviceMock))));
}
@@ -60,6 +69,10 @@
return;
}
+ EXPECT_CALL(mDeviceLostCallback,
+ Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Destroyed, _))
+ .Times(AtMost(1));
+
// Since the device owns the instance in these tests, we need to explicitly verify that the
// instance has completed all work. To do this, we take an additional ref to the instance here
// and use it to process events until completion after dropping the device.
diff --git a/src/dawn/tests/unittests/native/mocks/DawnMockTest.h b/src/dawn/tests/unittests/native/mocks/DawnMockTest.h
index c02548c..a6a2014 100644
--- a/src/dawn/tests/unittests/native/mocks/DawnMockTest.h
+++ b/src/dawn/tests/unittests/native/mocks/DawnMockTest.h
@@ -28,11 +28,17 @@
#include <gtest/gtest.h>
#include <webgpu/webgpu_cpp.h>
+#include "dawn/tests/MockCallback.h"
#include "dawn/tests/unittests/native/mocks/DeviceMock.h"
#include "partition_alloc/pointers/raw_ptr.h"
namespace dawn::native {
+// Matcher for C++ types to verify that their internal C-handles are identical.
+MATCHER_P(CHandleIs, cType, "") {
+ return arg.Get() == cType;
+}
+
class DawnMockTest : public ::testing::Test {
public:
DawnMockTest();
@@ -44,7 +50,12 @@
void SetUp() override;
void DropDevice();
- DeviceDescriptor mDeviceDescriptor;
+ // Device mock callbacks used throughout the tests.
+ testing::StrictMock<testing::MockCppCallback<wgpu::UncapturedErrorCallback<void>*>>
+ mDeviceErrorCallback;
+ testing::StrictMock<testing::MockCppCallback<wgpu::DeviceLostCallback2<void>*>>
+ mDeviceLostCallback;
+
TogglesState mDeviceToggles;
raw_ptr<::testing::NiceMock<DeviceMock>> mDeviceMock;
wgpu::Device device;
diff --git a/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp b/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp
index 041fa7f..a0a1fb6 100644
--- a/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireDisconnectTests.cpp
@@ -84,7 +84,7 @@
// Check that disconnecting the wire client calls the device lost callback exacty once.
TEST_F(WireDisconnectTests, CallsDeviceLostCallback) {
// Disconnect the wire client. We should receive device lost only once.
- EXPECT_CALL(deviceLostCallback, Call(_, WGPUDeviceLostReason_InstanceDropped, _, this))
+ EXPECT_CALL(deviceLostCallback, Call(_, wgpu::DeviceLostReason::InstanceDropped, _))
.Times(Exactly(1));
GetWireClient()->Disconnect();
GetWireClient()->Disconnect();
@@ -97,45 +97,34 @@
ToOutputStringView("some reason"));
// Flush the device lost return command.
- EXPECT_CALL(deviceLostCallback,
- Call(_, WGPUDeviceLostReason_Unknown, SizedString("some reason"), this))
+ EXPECT_CALL(deviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedString("some reason")))
.Times(Exactly(1));
FlushServer();
// Disconnect the client. We shouldn't see the lost callback again.
- EXPECT_CALL(deviceLostCallback, Call).Times(Exactly(0));
GetWireClient()->Disconnect();
}
// Check that disconnecting the wire client inside the device loss callback does not trigger the
// callback again.
TEST_F(WireDisconnectTests, ServerLostThenDisconnectInCallback) {
- MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
- device.SetDeviceLostCallback(mockDeviceLostCallback.Callback(),
- mockDeviceLostCallback.MakeUserdata(this));
-
api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Unknown,
ToOutputStringView("lost reason"));
// Disconnect the client inside the lost callback. We should see the callback
// only once.
- EXPECT_CALL(mockDeviceLostCallback,
- Call(WGPUDeviceLostReason_Unknown, SizedString("lost reason"), this))
- .WillOnce(InvokeWithoutArgs([&] {
- EXPECT_CALL(mockDeviceLostCallback, Call(_, _, _)).Times(Exactly(0));
- GetWireClient()->Disconnect();
- }));
+ EXPECT_CALL(deviceLostCallback,
+ Call(_, wgpu::DeviceLostReason::Unknown, SizedString("lost reason")))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs([&] { GetWireClient()->Disconnect(); }));
FlushServer();
}
// Check that a device loss after a disconnect does not trigger the callback again.
TEST_F(WireDisconnectTests, DisconnectThenServerLost) {
- MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
- device.SetDeviceLostCallback(mockDeviceLostCallback.Callback(),
- mockDeviceLostCallback.MakeUserdata(this));
-
// Disconnect the client. We should see the callback once.
- EXPECT_CALL(mockDeviceLostCallback, Call(WGPUDeviceLostReason_InstanceDropped, _, this))
+ EXPECT_CALL(deviceLostCallback, Call(_, wgpu::DeviceLostReason::InstanceDropped, _))
.Times(Exactly(1));
GetWireClient()->Disconnect();
@@ -143,7 +132,6 @@
// called again.
api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Unknown,
ToOutputStringView("lost reason"));
- EXPECT_CALL(mockDeviceLostCallback, Call(_, _, _)).Times(Exactly(0));
FlushServer();
}
diff --git a/src/dawn/tests/unittests/wire/WireErrorCallbackTests.cpp b/src/dawn/tests/unittests/wire/WireErrorCallbackTests.cpp
index 62e7270..4fc1c3d 100644
--- a/src/dawn/tests/unittests/wire/WireErrorCallbackTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireErrorCallbackTests.cpp
@@ -49,16 +49,6 @@
using testing::StrictMock;
// Mock classes to add expectations on the wire calling callbacks
-class MockDeviceErrorCallback {
- public:
- MOCK_METHOD(void, Call, (WGPUErrorType type, WGPUStringView message, void* userdata));
-};
-
-std::unique_ptr<StrictMock<MockDeviceErrorCallback>> mockDeviceErrorCallback;
-void ToMockDeviceErrorCallback(WGPUErrorType type, WGPUStringView message, void* userdata) {
- mockDeviceErrorCallback->Call(type, message, userdata);
-}
-
class MockDeviceLoggingCallback {
public:
MOCK_METHOD(void, Call, (WGPULoggingType type, WGPUStringView message, void* userdata));
@@ -69,16 +59,6 @@
mockDeviceLoggingCallback->Call(type, message, userdata);
}
-class MockDeviceLostCallback {
- public:
- MOCK_METHOD(void, Call, (WGPUDeviceLostReason reason, WGPUStringView message, void* userdata));
-};
-
-std::unique_ptr<StrictMock<MockDeviceLostCallback>> mockDeviceLostCallback;
-void ToMockDeviceLostCallback(WGPUDeviceLostReason reason, WGPUStringView message, void* userdata) {
- mockDeviceLostCallback->Call(reason, message, userdata);
-}
-
class WireErrorCallbackTests : public WireTest {
public:
WireErrorCallbackTests() {}
@@ -86,84 +66,15 @@
void SetUp() override {
WireTest::SetUp();
-
- mockDeviceErrorCallback = std::make_unique<StrictMock<MockDeviceErrorCallback>>();
mockDeviceLoggingCallback = std::make_unique<StrictMock<MockDeviceLoggingCallback>>();
- mockDeviceLostCallback = std::make_unique<StrictMock<MockDeviceLostCallback>>();
}
void TearDown() override {
WireTest::TearDown();
-
- mockDeviceErrorCallback = nullptr;
mockDeviceLoggingCallback = nullptr;
- mockDeviceLostCallback = nullptr;
- }
-
- void FlushServer() {
- WireTest::FlushServer();
-
- Mock::VerifyAndClearExpectations(&mockDeviceErrorCallback);
}
};
-// Test the return wire for device validation error callbacks
-TEST_F(WireErrorCallbackTests, DeviceValidationErrorCallback) {
- device.SetUncapturedErrorCallback(ToMockDeviceErrorCallback, this);
-
- // Setting the error callback should stay on the client side and do nothing
- FlushClient();
-
- // Calling the callback on the server side will result in the callback being called on the
- // client side
- api.CallDeviceSetUncapturedErrorCallbackCallback(apiDevice, WGPUErrorType_Validation,
- ToOutputStringView("Some error message"));
-
- EXPECT_CALL(*mockDeviceErrorCallback,
- Call(WGPUErrorType_Validation, SizedString("Some error message"), this))
- .Times(1);
-
- FlushServer();
-}
-
-// Test the return wire for device OOM error callbacks
-TEST_F(WireErrorCallbackTests, DeviceOutOfMemoryErrorCallback) {
- device.SetUncapturedErrorCallback(ToMockDeviceErrorCallback, this);
-
- // Setting the error callback should stay on the client side and do nothing
- FlushClient();
-
- // Calling the callback on the server side will result in the callback being called on the
- // client side
- api.CallDeviceSetUncapturedErrorCallbackCallback(apiDevice, WGPUErrorType_OutOfMemory,
- ToOutputStringView("Some error message"));
-
- EXPECT_CALL(*mockDeviceErrorCallback,
- Call(WGPUErrorType_OutOfMemory, SizedString("Some error message"), this))
- .Times(1);
-
- FlushServer();
-}
-
-// Test the return wire for device internal error callbacks
-TEST_F(WireErrorCallbackTests, DeviceInternalErrorCallback) {
- device.SetUncapturedErrorCallback(ToMockDeviceErrorCallback, this);
-
- // Setting the error callback should stay on the client side and do nothing
- FlushClient();
-
- // Calling the callback on the server side will result in the callback being called on the
- // client side
- api.CallDeviceSetUncapturedErrorCallbackCallback(apiDevice, WGPUErrorType_Internal,
- ToOutputStringView("Some error message"));
-
- EXPECT_CALL(*mockDeviceErrorCallback,
- Call(WGPUErrorType_Internal, SizedString("Some error message"), this))
- .Times(1);
-
- FlushServer();
-}
-
// Test the return wire for device user warning callbacks
TEST_F(WireErrorCallbackTests, DeviceLoggingCallback) {
device.SetLoggingCallback(ToMockDeviceLoggingCallback, this);
@@ -183,20 +94,33 @@
FlushServer();
}
-// Test the return wire for device lost callback
+// Test the return wire for device error callbacks.
+TEST_F(WireErrorCallbackTests, DeviceErrorCallbacks) {
+ static constexpr std::array<wgpu::ErrorType, 3> kErrorTypes = {
+ wgpu::ErrorType::Validation, wgpu::ErrorType::OutOfMemory, wgpu::ErrorType::Internal};
+
+ for (auto type : kErrorTypes) {
+ // Calling the callback on the server side will result in the callback being called on the
+ // client side when the server is flushed.
+ api.CallDeviceSetUncapturedErrorCallbackCallback(
+ apiDevice, static_cast<WGPUErrorType>(type), ToOutputStringView("Some error message"));
+ EXPECT_CALL(uncapturedErrorCallback,
+ Call(CHandleIs(device.Get()), type, SizedString("Some error message")))
+ .Times(1);
+ }
+
+ FlushServer();
+}
+
+// Test the return wire for device lost callback.
TEST_F(WireErrorCallbackTests, DeviceLostCallback) {
- wgpuDeviceSetDeviceLostCallback(cDevice, ToMockDeviceLostCallback, this);
-
- // Setting the error callback should stay on the client side and do nothing
- FlushClient();
-
// Calling the callback on the server side will result in the callback being called on the
- // client side
+ // client side when the server is flushed.
api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Unknown,
ToOutputStringView("Some error message"));
- EXPECT_CALL(*mockDeviceLostCallback,
- Call(WGPUDeviceLostReason_Unknown, SizedString("Some error message"), this))
+ EXPECT_CALL(deviceLostCallback, Call(CHandleIs(device.Get()), wgpu::DeviceLostReason::Unknown,
+ SizedString("Some error message")))
.Times(1);
FlushServer();
diff --git a/src/dawn/tests/unittests/wire/WireTest.cpp b/src/dawn/tests/unittests/wire/WireTest.cpp
index c498831..95fc06e 100644
--- a/src/dawn/tests/unittests/wire/WireTest.cpp
+++ b/src/dawn/tests/unittests/wire/WireTest.cpp
@@ -133,11 +133,10 @@
// Create the device for testing.
apiDevice = api.GetNewDevice();
wgpu::DeviceDescriptor deviceDesc = {};
- deviceDesc.deviceLostCallbackInfo = {nullptr, wgpu::CallbackMode::AllowSpontaneous,
- deviceLostCallback.Callback(),
- deviceLostCallback.MakeUserdata(this)};
- deviceDesc.uncapturedErrorCallbackInfo = {nullptr, uncapturedErrorCallback.Callback(),
- uncapturedErrorCallback.MakeUserdata(this)};
+ deviceDesc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
+ deviceLostCallback.Callback());
+ deviceDesc.SetUncapturedErrorCallback(uncapturedErrorCallback.TemplatedCallback(),
+ uncapturedErrorCallback.TemplatedCallbackUserdata());
EXPECT_CALL(deviceLostCallback, Call).Times(AtMost(1));
MockCallback<void (*)(wgpu::RequestDeviceStatus, wgpu::Device, wgpu::StringView, void*)>
diff --git a/src/dawn/tests/unittests/wire/WireTest.h b/src/dawn/tests/unittests/wire/WireTest.h
index 0acb7fb..f809a6c 100644
--- a/src/dawn/tests/unittests/wire/WireTest.h
+++ b/src/dawn/tests/unittests/wire/WireTest.h
@@ -111,6 +111,11 @@
return MakeMatcher(new StringMessageMatcher());
}
+// Matcher for C++ types to verify that their internal C-handles are identical.
+MATCHER_P(CHandleIs, cType, "") {
+ return arg.Get() == cType;
+}
+
// Skip a test when the given condition is satisfied.
#define DAWN_SKIP_TEST_IF(condition) \
do { \
@@ -152,8 +157,13 @@
testing::StrictMock<MockProcTable> api;
- testing::MockCallback<WGPUDeviceLostCallbackNew> deviceLostCallback;
- testing::MockCallback<WGPUErrorCallback> uncapturedErrorCallback;
+ // Mock callbacks tracking errors and destruction. These are strict mocks because any errors or
+ // device loss that aren't expected should result in test failures and not just some warnings
+ // printed to stdout.
+ testing::StrictMock<testing::MockCppCallback<wgpu::DeviceLostCallback2<void>*>>
+ deviceLostCallback;
+ testing::StrictMock<testing::MockCppCallback<wgpu::UncapturedErrorCallback<void>*>>
+ uncapturedErrorCallback;
wgpu::Instance instance;
WGPUInstance apiInstance;