dawn_wire: Add Reserve/InjectDevice
Now that the wire does enough tracking to prevent a malicious client
from freeing a device before its child objects, and the device is no
longer a "special" object with regard to reference/release, it is
safe to support multiple devices on the wire. The simplest way to
use this in WebGPU (to fix createReadyRenderPipeline validation)
is to add a reserve/inject device API similar to the one we use for
swapchain textures.
Bug: dawn:565
Change-Id: Ie956aff528c5610c9ecc5c189dab2d22185cb572
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/37800
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/generator/templates/dawn_wire/server/ServerDoers.cpp b/generator/templates/dawn_wire/server/ServerDoers.cpp
index 0c6ce42..0252336 100644
--- a/generator/templates/dawn_wire/server/ServerDoers.cpp
+++ b/generator/templates/dawn_wire/server/ServerDoers.cpp
@@ -98,6 +98,11 @@
*data->childObjectTypesAndIds.begin());
DoDestroyObject(childObjectType, childObjectId);
}
+ if (data->handle != nullptr) {
+ //* Deregisters uncaptured error and device lost callbacks since
+ //* they should not be forwarded if the device no longer exists on the wire.
+ ClearDeviceCallbacks(data->handle);
+ }
{% endif %}
if (data->handle != nullptr) {
mProcs.{{as_varName(type.name, Name("release"))}}(data->handle);
diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp
index de73a84..0dcea37 100644
--- a/src/dawn_wire/WireClient.cpp
+++ b/src/dawn_wire/WireClient.cpp
@@ -37,6 +37,10 @@
return mImpl->ReserveTexture(device);
}
+ ReservedDevice WireClient::ReserveDevice() {
+ return mImpl->ReserveDevice();
+ }
+
void WireClient::Disconnect() {
mImpl->Disconnect();
}
diff --git a/src/dawn_wire/WireServer.cpp b/src/dawn_wire/WireServer.cpp
index 723f691..763b5fc 100644
--- a/src/dawn_wire/WireServer.cpp
+++ b/src/dawn_wire/WireServer.cpp
@@ -40,6 +40,14 @@
return mImpl->InjectTexture(texture, id, generation, deviceId, deviceGeneration);
}
+ bool WireServer::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) {
+ return mImpl->InjectDevice(device, id, generation);
+ }
+
+ WGPUDevice WireServer::GetDevice(uint32_t id, uint32_t generation) {
+ return mImpl->GetDevice(id, generation);
+ }
+
namespace server {
MemoryTransferService::MemoryTransferService() = default;
diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp
index 0ca5f61..b566529 100644
--- a/src/dawn_wire/client/Client.cpp
+++ b/src/dawn_wire/client/Client.cpp
@@ -85,8 +85,13 @@
}
WGPUDevice Client::GetDevice() {
+ // This function is deprecated. The concept of a "default" device on the wire
+ // will be removed in favor of ReserveDevice/InjectDevice.
if (mDevice == nullptr) {
- mDevice = DeviceAllocator().New(this)->object.get();
+ ReservedDevice reservation = ReserveDevice();
+ mDevice = FromAPI(reservation.device);
+ ASSERT(reservation.id == 1);
+ ASSERT(reservation.generation == 0);
}
return reinterpret_cast<WGPUDeviceImpl*>(mDevice);
}
@@ -103,6 +108,16 @@
return result;
}
+ ReservedDevice Client::ReserveDevice() {
+ auto* allocation = DeviceAllocator().New(this);
+
+ ReservedDevice result;
+ result.device = ToAPI(allocation->object.get());
+ result.id = allocation->object->id;
+ result.generation = allocation->generation;
+ return result;
+ }
+
void Client::Disconnect() {
mDisconnected = true;
mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance());
diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h
index 4902df8..dd7ac76 100644
--- a/src/dawn_wire/client/Client.h
+++ b/src/dawn_wire/client/Client.h
@@ -46,6 +46,7 @@
}
ReservedTexture ReserveTexture(WGPUDevice device);
+ ReservedDevice ReserveDevice();
template <typename Cmd>
void SerializeCommand(const Cmd& cmd) {
diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp
index b6ee4fc..2d643cb 100644
--- a/src/dawn_wire/client/Device.cpp
+++ b/src/dawn_wire/client/Device.cpp
@@ -45,15 +45,6 @@
}
};
#endif // DAWN_ENABLE_ASSERTS
- // Get the default queue for this device.
- auto* allocation = client->QueueAllocator().New(client);
- mDefaultQueue = allocation->object.get();
-
- DeviceGetDefaultQueueCmd cmd;
- cmd.self = ToAPI(this);
- cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
-
- client->SerializeCommand(cmd);
}
Device::~Device() {
@@ -206,6 +197,22 @@
}
WGPUQueue Device::GetDefaultQueue() {
+ // The queue is lazily created because if a Device is created by
+ // Reserve/Inject, we cannot send the getDefaultQueue message until
+ // it has been injected on the Server. It cannot happen immediately
+ // on construction.
+ if (mDefaultQueue == nullptr) {
+ // Get the default queue for this device.
+ auto* allocation = client->QueueAllocator().New(client);
+ mDefaultQueue = allocation->object.get();
+
+ DeviceGetDefaultQueueCmd cmd;
+ cmd.self = ToAPI(this);
+ cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
+
+ client->SerializeCommand(cmd);
+ }
+
mDefaultQueue->refcount++;
return ToAPI(mDefaultQueue);
}
diff --git a/src/dawn_wire/server/ObjectStorage.h b/src/dawn_wire/server/ObjectStorage.h
index 74cc5a7..c803f53 100644
--- a/src/dawn_wire/server/ObjectStorage.h
+++ b/src/dawn_wire/server/ObjectStorage.h
@@ -160,6 +160,17 @@
return objects;
}
+ std::vector<T> GetAllHandles() {
+ std::vector<T> objects;
+ for (Data& data : mKnown) {
+ if (data.allocated && data.handle != nullptr) {
+ objects.push_back(data.handle);
+ }
+ }
+
+ return objects;
+ }
+
private:
std::vector<Data> mKnown;
};
diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp
index 67a9dd5..39e50ea 100644
--- a/src/dawn_wire/server/Server.cpp
+++ b/src/dawn_wire/server/Server.cpp
@@ -23,7 +23,6 @@
MemoryTransferService* memoryTransferService)
: mSerializer(serializer),
mProcs(procs),
- mDeviceOnCreation(device),
mMemoryTransferService(memoryTransferService),
mIsAlive(std::make_shared<bool>(true)) {
if (mMemoryTransferService == nullptr) {
@@ -31,38 +30,21 @@
mOwnedMemoryTransferService = CreateInlineMemoryTransferService();
mMemoryTransferService = mOwnedMemoryTransferService.get();
}
- // The client-server knowledge is bootstrapped with device 1.
- auto* deviceData = DeviceObjects().Allocate(1);
- deviceData->handle = device;
- // Take an extra ref. All objects may be freed by the client, but this
- // one is externally owned.
- mProcs.deviceReference(device);
-
- // Note: these callbacks are manually inlined here since they do not acquire and
- // free their userdata.
- mProcs.deviceSetUncapturedErrorCallback(
- device,
- [](WGPUErrorType type, const char* message, void* userdata) {
- Server* server = static_cast<Server*>(userdata);
- server->OnUncapturedError(type, message);
- },
- this);
- mProcs.deviceSetDeviceLostCallback(
- device,
- [](const char* message, void* userdata) {
- Server* server = static_cast<Server*>(userdata);
- server->OnDeviceLost(message);
- },
- this);
+ // For the deprecated initialization path:
+ // The client-server knowledge is bootstrapped with device 1, generation 0.
+ if (device != nullptr) {
+ bool success = InjectDevice(device, 1, 0);
+ ASSERT(success);
+ }
}
Server::~Server() {
// Un-set the error and lost callbacks since we cannot forward them
// after the server has been destroyed.
- mProcs.deviceSetUncapturedErrorCallback(mDeviceOnCreation, nullptr, nullptr);
- mProcs.deviceSetDeviceLostCallback(mDeviceOnCreation, nullptr, nullptr);
-
+ for (WGPUDevice device : DeviceObjects().GetAllHandles()) {
+ ClearDeviceCallbacks(device);
+ }
DestroyAllObjects(mProcs);
}
@@ -71,6 +53,7 @@
uint32_t generation,
uint32_t deviceId,
uint32_t deviceGeneration) {
+ ASSERT(texture != nullptr);
ObjectData<WGPUDevice>* device = DeviceObjects().Get(deviceId);
if (device == nullptr || device->generation != deviceGeneration) {
return false;
@@ -97,6 +80,57 @@
return true;
}
+ bool Server::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) {
+ ASSERT(device != nullptr);
+ ObjectData<WGPUDevice>* data = DeviceObjects().Allocate(id);
+ if (data == nullptr) {
+ return false;
+ }
+
+ data->handle = device;
+ data->generation = generation;
+ data->allocated = true;
+
+ // The device is externally owned so it shouldn't be destroyed when we receive a destroy
+ // message from the client. Add a reference to counterbalance the eventual release.
+ mProcs.deviceReference(device);
+
+ // Set callbacks to forward errors to the client.
+ // Note: these callbacks are manually inlined here since they do not acquire and
+ // free their userdata.
+ mProcs.deviceSetUncapturedErrorCallback(
+ device,
+ [](WGPUErrorType type, const char* message, void* userdata) {
+ Server* server = static_cast<Server*>(userdata);
+ server->OnUncapturedError(type, message);
+ },
+ this);
+ mProcs.deviceSetDeviceLostCallback(
+ device,
+ [](const char* message, void* userdata) {
+ Server* server = static_cast<Server*>(userdata);
+ server->OnDeviceLost(message);
+ },
+ this);
+
+ return true;
+ }
+
+ WGPUDevice Server::GetDevice(uint32_t id, uint32_t generation) {
+ ObjectData<WGPUDevice>* data = DeviceObjects().Get(id);
+ if (data == nullptr || data->generation != generation) {
+ return nullptr;
+ }
+ return data->handle;
+ }
+
+ void Server::ClearDeviceCallbacks(WGPUDevice device) {
+ // Un-set the error and lost callbacks since we cannot forward them
+ // after the server has been destroyed.
+ mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
+ mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
+ }
+
bool TrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id) {
auto it = static_cast<ObjectData<WGPUDevice>*>(device)->childObjectTypesAndIds.insert(
PackObjectTypeAndId(type, id));
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index f45ed0d..4056896 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -167,6 +167,10 @@
uint32_t deviceId,
uint32_t deviceGeneration);
+ bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation);
+
+ WGPUDevice GetDevice(uint32_t id, uint32_t generation);
+
template <typename T,
typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>>
std::unique_ptr<T> MakeUserdata() {
@@ -186,6 +190,7 @@
mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
}
+ void ClearDeviceCallbacks(WGPUDevice device);
// Error callbacks
void OnUncapturedError(WGPUErrorType type, const char* message);
@@ -212,7 +217,6 @@
WireDeserializeAllocator mAllocator;
ChunkedCommandSerializer mSerializer;
DawnProcTable mProcs;
- WGPUDevice mDeviceOnCreation;
std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
MemoryTransferService* mMemoryTransferService = nullptr;
diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h
index 8af02a97..b8f1247 100644
--- a/src/include/dawn_wire/WireClient.h
+++ b/src/include/dawn_wire/WireClient.h
@@ -38,6 +38,12 @@
uint32_t deviceGeneration;
};
+ struct ReservedDevice {
+ WGPUDevice device;
+ uint32_t id;
+ uint32_t generation;
+ };
+
struct DAWN_WIRE_EXPORT WireClientDescriptor {
CommandSerializer* serializer;
client::MemoryTransferService* memoryTransferService = nullptr;
@@ -53,6 +59,7 @@
size_t size) override final;
ReservedTexture ReserveTexture(WGPUDevice device);
+ ReservedDevice ReserveDevice();
// Disconnects the client.
// Commands allocated after this point will not be sent.
diff --git a/src/include/dawn_wire/WireServer.h b/src/include/dawn_wire/WireServer.h
index ad36f44..9ff6fed 100644
--- a/src/include/dawn_wire/WireServer.h
+++ b/src/include/dawn_wire/WireServer.h
@@ -50,6 +50,17 @@
uint32_t deviceId = 1,
uint32_t deviceGeneration = 0);
+ bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation);
+
+ // Look up a device by (id, generation) pair. Returns nullptr if the generation
+ // has expired or the id is not found.
+ // The Wire does not have destroy hooks to allow an embedder to observe when an object
+ // has been destroyed, but in Chrome, we need to know the list of live devices so we
+ // can call device.Tick() on all of them periodically to ensure progress on asynchronous
+ // work is made. Getting this list can be done by tracking the (id, generation) of
+ // previously injected devices, and observing if GetDevice(id, generation) returns non-null.
+ WGPUDevice GetDevice(uint32_t id, uint32_t generation);
+
private:
std::unique_ptr<server::Server> mImpl;
};
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 5a3e310..e614db2 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -223,6 +223,7 @@
"unittests/wire/WireErrorCallbackTests.cpp",
"unittests/wire/WireExtensionTests.cpp",
"unittests/wire/WireFenceTests.cpp",
+ "unittests/wire/WireInjectDeviceTests.cpp",
"unittests/wire/WireInjectTextureTests.cpp",
"unittests/wire/WireMemoryTransferServiceTests.cpp",
"unittests/wire/WireMultipleDeviceTests.cpp",
diff --git a/src/tests/unittests/wire/WireDestroyObjectTests.cpp b/src/tests/unittests/wire/WireDestroyObjectTests.cpp
index 34b976d..2c7ddc2 100644
--- a/src/tests/unittests/wire/WireDestroyObjectTests.cpp
+++ b/src/tests/unittests/wire/WireDestroyObjectTests.cpp
@@ -36,10 +36,19 @@
// The device and child objects should be released.
EXPECT_CALL(api, CommandEncoderRelease(apiEncoder)).InSequence(s1);
EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2);
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+ .Times(1)
+ .InSequence(s1, s2);
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+ .Times(1)
+ .InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2);
FlushClient();
+ // Signal that we already released and cleared callbacks for |apiDevice|
+ DefaultApiDeviceWasReleased();
+
// Using the command encoder should be an error.
wgpuCommandEncoderFinish(encoder, nullptr);
FlushClient(false);
@@ -82,8 +91,17 @@
// The device and child objects alre also released.
EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1);
EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2);
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+ .Times(1)
+ .InSequence(s1, s2);
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+ .Times(1)
+ .InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2);
FlushClient();
+
+ // Signal that we already released and cleared callbacks for |apiDevice|
+ DefaultApiDeviceWasReleased();
}
}
diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp
index f44df13..d3f65a9 100644
--- a/src/tests/unittests/wire/WireDisconnectTests.cpp
+++ b/src/tests/unittests/wire/WireDisconnectTests.cpp
@@ -149,6 +149,15 @@
EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1);
EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1).InSequence(s2);
EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1).InSequence(s3);
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+ .Times(1)
+ .InSequence(s1, s2);
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+ .Times(1)
+ .InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2, s3);
FlushClient();
+
+ // Signal that we already released and cleared callbacks for |apiDevice|
+ DefaultApiDeviceWasReleased();
}
diff --git a/src/tests/unittests/wire/WireInjectDeviceTests.cpp b/src/tests/unittests/wire/WireInjectDeviceTests.cpp
new file mode 100644
index 0000000..8f1dda3
--- /dev/null
+++ b/src/tests/unittests/wire/WireInjectDeviceTests.cpp
@@ -0,0 +1,184 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tests/unittests/wire/WireTest.h"
+
+#include "dawn_wire/WireClient.h"
+#include "dawn_wire/WireServer.h"
+
+using namespace testing;
+using namespace dawn_wire;
+
+class WireInjectDeviceTests : public WireTest {
+ public:
+ WireInjectDeviceTests() {
+ }
+ ~WireInjectDeviceTests() override = default;
+};
+
+// Test that reserving and injecting a device makes calls on the client object forward to the
+// server object correctly.
+TEST_F(WireInjectDeviceTests, CallAfterReserveInject) {
+ ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+ WGPUDevice serverDevice = api.GetNewDevice();
+ EXPECT_CALL(api, DeviceReference(serverDevice));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+ ASSERT_TRUE(
+ GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+ WGPUBufferDescriptor bufferDesc = {};
+ wgpuDeviceCreateBuffer(reservation.device, &bufferDesc);
+ WGPUBuffer serverBuffer = api.GetNewBuffer();
+ EXPECT_CALL(api, DeviceCreateBuffer(serverDevice, _)).WillOnce(Return(serverBuffer));
+ FlushClient();
+
+ // Called on shutdown.
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+}
+
+// Test that reserve correctly returns different IDs each time.
+TEST_F(WireInjectDeviceTests, ReserveDifferentIDs) {
+ ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
+ ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
+
+ ASSERT_NE(reservation1.id, reservation2.id);
+ ASSERT_NE(reservation1.device, reservation2.device);
+}
+
+// Test that injecting the same id without a destroy first fails.
+TEST_F(WireInjectDeviceTests, InjectExistingID) {
+ ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+ WGPUDevice serverDevice = api.GetNewDevice();
+ EXPECT_CALL(api, DeviceReference(serverDevice));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+ ASSERT_TRUE(
+ GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+ // ID already in use, call fails.
+ ASSERT_FALSE(
+ GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+ // Called on shutdown.
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+}
+
+// Test that the server only borrows the device and does a single reference-release
+TEST_F(WireInjectDeviceTests, InjectedDeviceLifetime) {
+ ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+ // Injecting the device adds a reference
+ WGPUDevice serverDevice = api.GetNewDevice();
+ EXPECT_CALL(api, DeviceReference(serverDevice));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+ ASSERT_TRUE(
+ GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+ // Releasing the device removes a single reference and clears its error callbacks.
+ wgpuDeviceRelease(reservation.device);
+ EXPECT_CALL(api, DeviceRelease(serverDevice));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)).Times(1);
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)).Times(1);
+ FlushClient();
+
+ // Deleting the server doesn't release a second reference.
+ DeleteServer();
+ Mock::VerifyAndClearExpectations(&api);
+}
+
+// Test that it is an error to get the default queue of a device before it has been
+// injected on the server.
+TEST_F(WireInjectDeviceTests, GetQueueBeforeInject) {
+ ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+ wgpuDeviceGetDefaultQueue(reservation.device);
+ FlushClient(false);
+}
+
+// Test that it is valid to get the default queue of a device after it has been
+// injected on the server.
+TEST_F(WireInjectDeviceTests, GetQueueAfterInject) {
+ ReservedDevice reservation = GetWireClient()->ReserveDevice();
+
+ WGPUDevice serverDevice = api.GetNewDevice();
+ EXPECT_CALL(api, DeviceReference(serverDevice));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
+ ASSERT_TRUE(
+ GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
+
+ wgpuDeviceGetDefaultQueue(reservation.device);
+
+ WGPUQueue apiQueue = api.GetNewQueue();
+ EXPECT_CALL(api, DeviceGetDefaultQueue(serverDevice)).WillOnce(Return(apiQueue));
+ FlushClient();
+
+ // Called on shutdown.
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+}
+
+// Test that the list of live devices can be reflected using GetDevice.
+TEST_F(WireInjectDeviceTests, ReflectLiveDevices) {
+ // Reserve two devices.
+ ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
+ ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
+
+ // Inject both devices.
+
+ WGPUDevice serverDevice1 = api.GetNewDevice();
+ EXPECT_CALL(api, DeviceReference(serverDevice1));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, _, _));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, _, _));
+ ASSERT_TRUE(
+ GetWireServer()->InjectDevice(serverDevice1, reservation1.id, reservation1.generation));
+
+ WGPUDevice serverDevice2 = api.GetNewDevice();
+ EXPECT_CALL(api, DeviceReference(serverDevice2));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, _, _));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, _, _));
+ ASSERT_TRUE(
+ GetWireServer()->InjectDevice(serverDevice2, reservation2.id, reservation2.generation));
+
+ // Test that both devices can be reflected.
+ ASSERT_EQ(serverDevice1, GetWireServer()->GetDevice(reservation1.id, reservation1.generation));
+ ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation));
+
+ // Release the first device
+ wgpuDeviceRelease(reservation1.device);
+ EXPECT_CALL(api, DeviceRelease(serverDevice1));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, nullptr, nullptr)).Times(1);
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, nullptr, nullptr)).Times(1);
+ FlushClient();
+
+ // The first device should no longer reflect, but the second should
+ ASSERT_EQ(nullptr, GetWireServer()->GetDevice(reservation1.id, reservation1.generation));
+ ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation));
+
+ // Called on shutdown.
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, nullptr, nullptr)).Times(1);
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, nullptr, nullptr)).Times(1);
+}
diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
index 75cedcd..216c122 100644
--- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
@@ -83,9 +83,10 @@
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
- EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr))
+ EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(mServerDevice, nullptr, nullptr))
.Times(Exactly(1));
- EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+ EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(mServerDevice, nullptr, nullptr))
+ .Times(Exactly(1));
mWireServer = nullptr;
}
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index 2609511..d3f17c3 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -88,15 +88,23 @@
api.IgnoreAllReleaseCalls();
mWireClient = nullptr;
- if (mWireServer) {
+ if (mWireServer && apiDevice) {
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
- EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+ .Times(Exactly(1));
}
mWireServer = nullptr;
}
+// This should be called if |apiDevice| is no longer exists on the wire.
+// This signals that expectations in |TearDowb| shouldn't be added.
+void WireTest::DefaultApiDeviceWasReleased() {
+ apiDevice = nullptr;
+}
+
void WireTest::FlushClient(bool success) {
ASSERT_EQ(mC2sBuf->Flush(), success);
@@ -123,8 +131,10 @@
if (mWireServer) {
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
- EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
- EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
+ .Times(Exactly(1));
+ EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
+ .Times(Exactly(1));
}
mWireServer = nullptr;
}
diff --git a/src/tests/unittests/wire/WireTest.h b/src/tests/unittests/wire/WireTest.h
index 95fd307..03ac641 100644
--- a/src/tests/unittests/wire/WireTest.h
+++ b/src/tests/unittests/wire/WireTest.h
@@ -123,6 +123,8 @@
void FlushClient(bool success = true);
void FlushServer(bool success = true);
+ void DefaultApiDeviceWasReleased();
+
testing::StrictMock<MockProcTable> api;
WGPUDevice apiDevice;
WGPUQueue apiQueue;