[dawn][wire] Make the server use a weak_ptr for callbacks.
- This is necessary to avoid raw_ptr exceptions. Note that before the
pointer was valid because of the alive-ness check through the
weak_ptr<bool>, but could potentially race because the server could
in theory be dropped between the check. This change also fixes that.
Bug: dawn:2450
Change-Id: I82755ece9d0e9cbe38653d48d762d17815ecf525
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/178740
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/include/dawn/wire/WireServer.h b/include/dawn/wire/WireServer.h
index 1c67262..a3b9972 100644
--- a/include/dawn/wire/WireServer.h
+++ b/include/dawn/wire/WireServer.h
@@ -73,7 +73,7 @@
bool IsDeviceKnown(WGPUDevice device) const;
private:
- std::unique_ptr<server::Server> mImpl;
+ std::shared_ptr<server::Server> mImpl;
};
namespace server {
diff --git a/src/dawn/wire/WireServer.cpp b/src/dawn/wire/WireServer.cpp
index 02067d6..9437e54 100644
--- a/src/dawn/wire/WireServer.cpp
+++ b/src/dawn/wire/WireServer.cpp
@@ -31,9 +31,9 @@
namespace dawn::wire {
WireServer::WireServer(const WireServerDescriptor& descriptor)
- : mImpl(new server::Server(*descriptor.procs,
- descriptor.serializer,
- descriptor.memoryTransferService)) {}
+ : mImpl(server::Server::Create(*descriptor.procs,
+ descriptor.serializer,
+ descriptor.memoryTransferService)) {}
WireServer::~WireServer() {
mImpl.reset();
diff --git a/src/dawn/wire/server/Server.cpp b/src/dawn/wire/server/Server.cpp
index 34cddbf..bf04246 100644
--- a/src/dawn/wire/server/Server.cpp
+++ b/src/dawn/wire/server/Server.cpp
@@ -30,16 +30,21 @@
namespace dawn::wire::server {
-CallbackUserdata::CallbackUserdata(Server* server, const std::shared_ptr<bool>& serverIsAlive)
- : server(server), serverIsAlive(serverIsAlive) {}
+CallbackUserdata::CallbackUserdata(const std::weak_ptr<Server>& server) : server(server) {}
+
+// static
+std::shared_ptr<Server> Server::Create(const DawnProcTable& procs,
+ CommandSerializer* serializer,
+ MemoryTransferService* memoryTransferService) {
+ auto server = std::shared_ptr<Server>(new Server(procs, serializer, memoryTransferService));
+ server->mSelf = server;
+ return server;
+}
Server::Server(const DawnProcTable& procs,
CommandSerializer* serializer,
MemoryTransferService* memoryTransferService)
- : mSerializer(serializer),
- mProcs(procs),
- mMemoryTransferService(memoryTransferService),
- mIsAlive(std::make_shared<bool>(true)) {
+ : mSerializer(serializer), mProcs(procs), mMemoryTransferService(memoryTransferService) {
if (mMemoryTransferService == nullptr) {
// If a MemoryTransferService is not provided, fallback to inline memory.
mOwnedMemoryTransferService = CreateInlineMemoryTransferService();
diff --git a/src/dawn/wire/server/Server.h b/src/dawn/wire/server/Server.h
index ccd0a72..1e15f1b 100644
--- a/src/dawn/wire/server/Server.h
+++ b/src/dawn/wire/server/Server.h
@@ -66,11 +66,10 @@
//
// void Server::MyCallbackHandler(MyUserdata* userdata, Other args) { }
struct CallbackUserdata {
- const raw_ptr<Server> server;
- std::weak_ptr<bool> const serverIsAlive;
+ const std::weak_ptr<Server> server;
CallbackUserdata() = delete;
- CallbackUserdata(Server* server, const std::shared_ptr<bool>& serverIsAlive);
+ explicit CallbackUserdata(const std::weak_ptr<Server>& server);
};
template <auto F>
@@ -85,12 +84,13 @@
static Return Callback(Args... args, void* userdata) {
// Acquire the userdata, and cast it to UserdataT.
std::unique_ptr<Userdata> data(static_cast<Userdata*>(userdata));
- if (data->serverIsAlive.expired()) {
+ auto server = data->server.lock();
+ if (!server) {
// Do nothing if the server has already been destroyed.
return;
}
// Forward the arguments and the typed userdata to the Server:: member function.
- (data->server->*F)(data.get(), std::forward<decltype(args)>(args)...);
+ (server.get()->*F)(data.get(), std::forward<decltype(args)>(args)...);
}
};
@@ -164,9 +164,9 @@
class Server : public ServerBase {
public:
- Server(const DawnProcTable& procs,
- CommandSerializer* serializer,
- MemoryTransferService* memoryTransferService);
+ static std::shared_ptr<Server> Create(const DawnProcTable& procs,
+ CommandSerializer* serializer,
+ MemoryTransferService* memoryTransferService);
~Server() override;
// ChunkedCommandHandler implementation
@@ -184,10 +184,14 @@
template <typename T,
typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>>
std::unique_ptr<T> MakeUserdata() {
- return std::unique_ptr<T>(new T(this, mIsAlive));
+ return std::unique_ptr<T>(new T(mSelf));
}
private:
+ Server(const DawnProcTable& procs,
+ CommandSerializer* serializer,
+ MemoryTransferService* memoryTransferService);
+
template <typename Cmd>
void SerializeCommand(const Cmd& cmd) {
mSerializer->SerializeCommand(cmd);
@@ -238,7 +242,8 @@
std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
raw_ptr<MemoryTransferService> mMemoryTransferService = nullptr;
- std::shared_ptr<bool> mIsAlive;
+ // Weak pointer to self to facilitate creation of userdata.
+ std::weak_ptr<Server> mSelf;
};
std::unique_ptr<MemoryTransferService> CreateInlineMemoryTransferService();