dawn::wire::client: Add Ref<T> and use it for refcounting.
Bug: 344963953
Change-Id: Ibb9f29534a6cb3f51c95bb2aad96fd10e512f2c2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195774
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/generator/templates/dawn/wire/client/ApiProcs.cpp b/generator/templates/dawn/wire/client/ApiProcs.cpp
index f529832..623d228 100644
--- a/generator/templates/dawn/wire/client/ApiProcs.cpp
+++ b/generator/templates/dawn/wire/client/ApiProcs.cpp
@@ -40,14 +40,14 @@
template <typename Parent, typename Child, typename... Args>
Child* Create(Parent p, Args... args) {
if constexpr (std::is_constructible_v<Child, const ObjectBaseParams&, decltype(args)...>) {
- return p->GetClient()->template Make<Child>(args...);
+ return p->GetClient()->template Make<Child>(args...).Detach();
} else if constexpr (std::is_constructible_v<Child, const ObjectBaseParams&, const ObjectHandle&, decltype(args)...>) {
- return p->GetClient()->template Make<Child>(p->GetEventManagerHandle(), args...);
+ return p->GetClient()->template Make<Child>(p->GetEventManagerHandle(), args...).Detach();
} else {
if constexpr (std::is_base_of_v<ObjectWithEventsBase, Child>) {
- return p->GetClient()->template Make<Child>(p->GetEventManagerHandle());
+ return p->GetClient()->template Make<Child>(p->GetEventManagerHandle()).Detach();
} else {
- return p->GetClient()->template Make<Child>();
+ return p->GetClient()->template Make<Child>().Detach();
}
}
}
diff --git a/src/dawn/wire/client/Adapter.cpp b/src/dawn/wire/client/Adapter.cpp
index 548df55..3b79825 100644
--- a/src/dawn/wire/client/Adapter.cpp
+++ b/src/dawn/wire/client/Adapter.cpp
@@ -29,6 +29,7 @@
#include <memory>
#include <string>
+#include <utility>
#include "dawn/common/Log.h"
#include "dawn/wire/client/Client.h"
@@ -42,18 +43,18 @@
public:
static constexpr EventType kType = EventType::RequestDevice;
- RequestDeviceEvent(const WGPURequestDeviceCallbackInfo& callbackInfo, Device* device)
+ RequestDeviceEvent(const WGPURequestDeviceCallbackInfo& callbackInfo, Ref<Device> device)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
mUserdata1(callbackInfo.userdata),
- mDevice(device) {}
+ mDevice(std::move(device)) {}
- RequestDeviceEvent(const WGPURequestDeviceCallbackInfo2& callbackInfo, Device* device)
+ RequestDeviceEvent(const WGPURequestDeviceCallbackInfo2& callbackInfo, Ref<Device> device)
: TrackedEvent(callbackInfo.mode),
mCallback2(callbackInfo.callback),
mUserdata1(callbackInfo.userdata1),
mUserdata2(callbackInfo.userdata2),
- mDevice(device) {}
+ mDevice(std::move(device)) {}
EventType GetType() override { return kType; }
@@ -82,14 +83,14 @@
mMessage = "A valid external Instance reference no longer exists.";
}
- Device* device = mDevice.ExtractAsDangling();
// Callback needs to happen before device lost handling to ensure resolution order.
if (mCallback) {
- mCallback(mStatus, ToAPI(mStatus == WGPURequestDeviceStatus_Success ? device : nullptr),
+ mCallback(mStatus,
+ mStatus == WGPURequestDeviceStatus_Success ? ReturnToAPI(mDevice) : nullptr,
mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling());
} else if (mCallback2) {
mCallback2(mStatus,
- ToAPI(mStatus == WGPURequestDeviceStatus_Success ? device : nullptr),
+ mStatus == WGPURequestDeviceStatus_Success ? ReturnToAPI(mDevice) : nullptr,
mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling(),
mUserdata2.ExtractAsDangling());
}
@@ -98,17 +99,16 @@
// If there was an error and we didn't return a device, we need to call the device lost
// callback and reclaim the device allocation.
if (mStatus == WGPURequestDeviceStatus_InstanceDropped) {
- device->HandleDeviceLost(WGPUDeviceLostReason_InstanceDropped,
- "A valid external Instance reference no longer exists.");
+ mDevice->HandleDeviceLost(WGPUDeviceLostReason_InstanceDropped,
+ "A valid external Instance reference no longer exists.");
} else {
- device->HandleDeviceLost(WGPUDeviceLostReason_FailedCreation,
- "Device failed at creation.");
+ mDevice->HandleDeviceLost(WGPUDeviceLostReason_FailedCreation,
+ "Device failed at creation.");
}
}
if (mCallback == nullptr && mCallback2 == nullptr) {
// If there's no callback, clean up the resources.
- device->Release();
mUserdata1.ExtractAsDangling();
mUserdata2.ExtractAsDangling();
}
@@ -128,7 +128,7 @@
// throughout the duration of a RequestDeviceEvent because the Event essentially takes
// ownership of it until either an error occurs at which point the Event cleans it up, or it
// returns the device to the user who then takes ownership as the Event goes away.
- raw_ptr<Device> mDevice = nullptr;
+ Ref<Device> mDevice;
};
} // anonymous namespace
@@ -366,7 +366,7 @@
WGPUFuture Adapter::RequestDeviceF(const WGPUDeviceDescriptor* descriptor,
const WGPURequestDeviceCallbackInfo& callbackInfo) {
Client* client = GetClient();
- Device* device = client->Make<Device>(GetEventManagerHandle(), descriptor);
+ Ref<Device> device = client->Make<Device>(GetEventManagerHandle(), descriptor);
auto [futureIDInternal, tracked] =
GetEventManager().TrackEvent(std::make_unique<RequestDeviceEvent>(callbackInfo, device));
if (!tracked) {
@@ -402,7 +402,7 @@
WGPUFuture Adapter::RequestDevice2(const WGPUDeviceDescriptor* descriptor,
const WGPURequestDeviceCallbackInfo2& callbackInfo) {
Client* client = GetClient();
- Device* device = client->Make<Device>(GetEventManagerHandle(), descriptor);
+ Ref<Device> device = client->Make<Device>(GetEventManagerHandle(), descriptor);
auto [futureIDInternal, tracked] =
GetEventManager().TrackEvent(std::make_unique<RequestDeviceEvent>(callbackInfo, device));
if (!tracked) {
diff --git a/src/dawn/wire/client/Buffer.cpp b/src/dawn/wire/client/Buffer.cpp
index c9b1a6f..3b3412d 100644
--- a/src/dawn/wire/client/Buffer.cpp
+++ b/src/dawn/wire/client/Buffer.cpp
@@ -58,17 +58,14 @@
public:
static constexpr EventType kType = EventType::MapAsync;
- MapAsyncEvent(const WGPUBufferMapCallbackInfo& callbackInfo, Buffer* buffer)
+ MapAsyncEvent(const WGPUBufferMapCallbackInfo& callbackInfo, Ref<Buffer> buffer)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
mUserdata(callbackInfo.userdata),
- mBuffer(buffer) {
- DAWN_ASSERT(buffer != nullptr);
- mBuffer->AddRef();
+ mBuffer(std::move(buffer)) {
+ DAWN_ASSERT(mBuffer != nullptr);
}
- ~MapAsyncEvent() override { mBuffer.ExtractAsDangling()->Release(); }
-
EventType GetType() override { return kType; }
bool IsPendingRequest(FutureID futureID) {
@@ -207,25 +204,22 @@
std::optional<WGPUBufferMapAsyncStatus> mStatus;
// Strong reference to the buffer so that when we call the callback we can pass the buffer.
- raw_ptr<Buffer> mBuffer;
+ Ref<Buffer> mBuffer;
};
class Buffer::MapAsyncEvent2 : public TrackedEvent {
public:
static constexpr EventType kType = EventType::MapAsync;
- MapAsyncEvent2(const WGPUBufferMapCallbackInfo2& callbackInfo, Buffer* buffer)
+ MapAsyncEvent2(const WGPUBufferMapCallbackInfo2& callbackInfo, Ref<Buffer> buffer)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
mUserdata1(callbackInfo.userdata1),
mUserdata2(callbackInfo.userdata2),
mBuffer(buffer) {
- DAWN_ASSERT(buffer != nullptr);
- mBuffer->AddRef();
+ DAWN_ASSERT(mBuffer != nullptr);
}
- ~MapAsyncEvent2() override { mBuffer.ExtractAsDangling()->Release(); }
-
EventType GetType() override { return kType; }
bool IsPendingRequest(FutureID futureID) {
@@ -339,7 +333,7 @@
std::optional<std::string> mMessage;
// Strong reference to the buffer so that when we call the callback we can pass the buffer.
- raw_ptr<Buffer> mBuffer;
+ Ref<Buffer> mBuffer;
};
// static
@@ -393,7 +387,7 @@
// Create the buffer and send the creation command.
// This must happen after any potential error buffer creation
// as server expects allocating ids to be monotonically increasing
- Buffer* buffer = wireClient->Make<Buffer>(device->GetEventManagerHandle(), descriptor);
+ Ref<Buffer> buffer = wireClient->Make<Buffer>(device->GetEventManagerHandle(), descriptor);
buffer->mIsDeviceAlive = device->GetAliveWeakPtr();
if (descriptor->mappedAtCreation) {
@@ -431,7 +425,7 @@
}
}});
// clang-format on
- return ToAPI(buffer);
+ return ReturnToAPI(std::move(buffer));
}
Buffer::Buffer(const ObjectBaseParams& params,
diff --git a/src/dawn/wire/client/Client.cpp b/src/dawn/wire/client/Client.cpp
index bacb0f1..a24445b 100644
--- a/src/dawn/wire/client/Client.cpp
+++ b/src/dawn/wire/client/Client.cpp
@@ -100,41 +100,40 @@
}
ReservedBuffer Client::ReserveBuffer(WGPUDevice device, const WGPUBufferDescriptor* descriptor) {
- Buffer* buffer = Make<Buffer>(FromAPI(device)->GetEventManagerHandle(), descriptor);
+ Ref<Buffer> buffer = Make<Buffer>(FromAPI(device)->GetEventManagerHandle(), descriptor);
ReservedBuffer result;
- result.buffer = ToAPI(buffer);
result.handle = buffer->GetWireHandle();
result.deviceHandle = FromAPI(device)->GetWireHandle();
+ result.buffer = ReturnToAPI(buffer);
return result;
}
ReservedTexture Client::ReserveTexture(WGPUDevice device, const WGPUTextureDescriptor* descriptor) {
- Texture* texture = Make<Texture>(descriptor);
+ Ref<Texture> texture = Make<Texture>(descriptor);
ReservedTexture result;
- result.texture = ToAPI(texture);
result.handle = texture->GetWireHandle();
result.deviceHandle = FromAPI(device)->GetWireHandle();
+ result.texture = ReturnToAPI(texture);
return result;
}
ReservedSwapChain Client::ReserveSwapChain(WGPUDevice device,
const WGPUSwapChainDescriptor* descriptor) {
- SwapChain* swapChain = Make<SwapChain>(nullptr, descriptor);
+ Ref<SwapChain> swapChain = Make<SwapChain>(nullptr, descriptor);
ReservedSwapChain result;
- result.swapchain = ToAPI(swapChain);
result.handle = swapChain->GetWireHandle();
result.deviceHandle = FromAPI(device)->GetWireHandle();
+ result.swapchain = ReturnToAPI(swapChain);
return result;
}
ReservedInstance Client::ReserveInstance(const WGPUInstanceDescriptor* descriptor) {
- Instance* instance = Make<Instance>();
+ Ref<Instance> instance = Make<Instance>();
if (instance->Initialize(descriptor) != WireResult::Success) {
- Free(instance);
return {nullptr, {0, 0}};
}
@@ -142,8 +141,8 @@
mEventManagers.emplace(instance->GetWireHandle(), std::make_unique<EventManager>());
ReservedInstance result;
- result.instance = ToAPI(instance);
result.handle = instance->GetWireHandle();
+ result.instance = ReturnToAPI(instance);
return result;
}
diff --git a/src/dawn/wire/client/Client.h b/src/dawn/wire/client/Client.h
index d389bd5..af8191b 100644
--- a/src/dawn/wire/client/Client.h
+++ b/src/dawn/wire/client/Client.h
@@ -60,7 +60,7 @@
//
// T::T(ObjectBaseParams, arg1, arg2, arg3)
template <typename T, typename... Args>
- T* Make(Args&&... args) {
+ Ref<T> Make(Args&&... args) {
constexpr ObjectType type = ObjectTypeToTypeEnum<T>;
ObjectBaseParams params = {this, mObjectStores[type].ReserveHandle()};
@@ -68,7 +68,10 @@
mObjects[type].Append(object);
mObjectStores[type].Insert(std::unique_ptr<T>(object));
- return object;
+
+ Ref<T> ref;
+ ref.Acquire(object);
+ return ref;
}
void Free(ObjectBase* obj, ObjectType type);
diff --git a/src/dawn/wire/client/Device.cpp b/src/dawn/wire/client/Device.cpp
index c7337f6..6c2f64d 100644
--- a/src/dawn/wire/client/Device.cpp
+++ b/src/dawn/wire/client/Device.cpp
@@ -92,12 +92,12 @@
static constexpr EventType kType = Type;
- CreatePipelineEventBase(const CallbackInfo& callbackInfo, Pipeline* pipeline)
+ CreatePipelineEventBase(const CallbackInfo& callbackInfo, Ref<Pipeline> pipeline)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
mUserdata1(callbackInfo.userdata1),
mUserdata2(callbackInfo.userdata2),
- mPipeline(pipeline) {
+ mPipeline(std::move(pipeline)) {
DAWN_ASSERT(mPipeline != nullptr);
}
@@ -118,7 +118,6 @@
void CompleteImpl(FutureID futureID, EventCompletionType completionType) override {
auto userdata1 = mUserdata1.ExtractAsDangling();
auto userdata2 = mUserdata2.ExtractAsDangling();
- Pipeline* pipeline = mPipeline.ExtractAsDangling();
if (mCallback == nullptr) {
return;
@@ -129,9 +128,10 @@
mMessage = "A valid external Instance reference no longer exists.";
}
- mCallback(mStatus,
- ToAPI(mStatus == WGPUCreatePipelineAsyncStatus_Success ? pipeline : nullptr),
- mMessage ? mMessage->c_str() : nullptr, userdata1, userdata2);
+ mCallback(
+ mStatus,
+ mStatus == WGPUCreatePipelineAsyncStatus_Success ? ReturnToAPI(mPipeline) : nullptr,
+ mMessage ? mMessage->c_str() : nullptr, userdata1, userdata2);
}
using Callback = decltype(std::declval<CallbackInfo>().callback);
@@ -144,7 +144,7 @@
WGPUCreatePipelineAsyncStatus mStatus = WGPUCreatePipelineAsyncStatus_Success;
std::optional<std::string> mMessage;
- raw_ptr<Pipeline> mPipeline = nullptr;
+ Ref<Pipeline> mPipeline;
};
using CreateComputePipelineEvent =
@@ -201,18 +201,15 @@
public:
static constexpr EventType kType = EventType::DeviceLost;
- DeviceLostEvent(const WGPUDeviceLostCallbackInfo2& callbackInfo, Device* device)
- : TrackedEvent(callbackInfo.mode), mDevice(device) {
- DAWN_ASSERT(device != nullptr);
- mDevice->AddRef();
+ DeviceLostEvent(const WGPUDeviceLostCallbackInfo2& callbackInfo, Ref<Device> device)
+ : TrackedEvent(callbackInfo.mode), mDevice(std::move(device)) {
+ DAWN_ASSERT(mDevice != nullptr);
mDevice->mDeviceLostInfo.callback = callbackInfo.callback;
mDevice->mDeviceLostInfo.userdata1 = callbackInfo.userdata1;
mDevice->mDeviceLostInfo.userdata2 = callbackInfo.userdata2;
}
- ~DeviceLostEvent() override { mDevice.ExtractAsDangling()->Release(); }
-
EventType GetType() override { return kType; }
WireResult ReadyHook(FutureID futureID, WGPUDeviceLostReason reason, const char* message) {
@@ -235,7 +232,8 @@
void* userdata2 = mDevice->mDeviceLostInfo.userdata2.ExtractAsDangling();
if (mDevice->mDeviceLostInfo.callback != nullptr) {
- auto device = mReason != WGPUDeviceLostReason_FailedCreation ? ToAPI(mDevice) : nullptr;
+ const auto device =
+ mReason != WGPUDeviceLostReason_FailedCreation ? ToAPI(mDevice.Get()) : nullptr;
mDevice->mDeviceLostInfo.callback(
&device, mReason, mMessage ? mMessage->c_str() : nullptr, userdata1, userdata2);
}
@@ -248,7 +246,7 @@
std::optional<std::string> mMessage;
// Strong reference to the device so that when we call the callback we can pass the device.
- raw_ptr<Device> mDevice;
+ Ref<Device> mDevice;
};
Device::Device(const ObjectBaseParams& params,
@@ -318,12 +316,6 @@
}
}
-Device::~Device() {
- if (mQueue != nullptr) {
- mQueue.ExtractAsDangling()->Release();
- }
-}
-
ObjectType Device::GetObjectType() const {
return ObjectType::Device;
}
@@ -361,7 +353,7 @@
void Device::HandleError(WGPUErrorType errorType, const char* message) {
if (mUncapturedErrorCallbackInfo.callback) {
- auto device = ToAPI(this);
+ const auto device = ToAPI(this);
mUncapturedErrorCallbackInfo.callback(&device, errorType, message,
mUncapturedErrorCallbackInfo.userdata1,
mUncapturedErrorCallbackInfo.userdata2);
@@ -503,8 +495,7 @@
client->SerializeCommand(cmd);
}
- mQueue->AddRef();
- return ToAPI(mQueue);
+ return ReturnToAPI(mQueue);
}
template <typename Event, typename Cmd, typename CallbackInfo, typename Descriptor>
@@ -513,7 +504,7 @@
using Pipeline = typename Event::Pipeline;
Client* client = GetClient();
- Pipeline* pipeline = client->Make<Pipeline>();
+ Ref<Pipeline> pipeline = client->Make<Pipeline>();
auto [futureIDInternal, tracked] =
GetEventManager().TrackEvent(std::make_unique<Event>(callbackInfo, pipeline));
if (!tracked) {
diff --git a/src/dawn/wire/client/Device.h b/src/dawn/wire/client/Device.h
index 2a2d88e..6212e1d 100644
--- a/src/dawn/wire/client/Device.h
+++ b/src/dawn/wire/client/Device.h
@@ -48,7 +48,6 @@
explicit Device(const ObjectBaseParams& params,
const ObjectHandle& eventManagerHandle,
const WGPUDeviceDescriptor* descriptor);
- ~Device() override;
ObjectType GetObjectType() const override;
@@ -123,7 +122,7 @@
WGPULoggingCallback mLoggingCallback = nullptr;
raw_ptr<void> mLoggingUserdata = nullptr;
- raw_ptr<Queue> mQueue = nullptr;
+ Ref<Queue> mQueue;
std::shared_ptr<bool> mIsAlive;
};
diff --git a/src/dawn/wire/client/Instance.cpp b/src/dawn/wire/client/Instance.cpp
index e5ca6e9..85b49b4 100644
--- a/src/dawn/wire/client/Instance.cpp
+++ b/src/dawn/wire/client/Instance.cpp
@@ -48,18 +48,18 @@
public:
static constexpr EventType kType = EventType::RequestAdapter;
- RequestAdapterEvent(const WGPURequestAdapterCallbackInfo& callbackInfo, Adapter* adapter)
+ RequestAdapterEvent(const WGPURequestAdapterCallbackInfo& callbackInfo, Ref<Adapter> adapter)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
mUserdata1(callbackInfo.userdata),
- mAdapter(adapter) {}
+ mAdapter(std::move(adapter)) {}
- RequestAdapterEvent(const WGPURequestAdapterCallbackInfo2& callbackInfo, Adapter* adapter)
+ RequestAdapterEvent(const WGPURequestAdapterCallbackInfo2& callbackInfo, Ref<Adapter> adapter)
: TrackedEvent(callbackInfo.mode),
mCallback2(callbackInfo.callback),
mUserdata1(callbackInfo.userdata1),
mUserdata2(callbackInfo.userdata2),
- mAdapter(adapter) {}
+ mAdapter(std::move(adapter)) {}
EventType GetType() override { return kType; }
@@ -89,7 +89,6 @@
void CompleteImpl(FutureID futureID, EventCompletionType completionType) override {
if (mCallback == nullptr && mCallback2 == nullptr) {
// If there's no callback, just clean up the resources.
- mAdapter.ExtractAsDangling()->Release();
mUserdata1.ExtractAsDangling();
mUserdata2.ExtractAsDangling();
return;
@@ -100,16 +99,16 @@
mMessage = "A valid external Instance reference no longer exists.";
}
- Adapter* adapter = mAdapter.ExtractAsDangling();
if (mCallback) {
mCallback(mStatus,
- ToAPI(mStatus == WGPURequestAdapterStatus_Success ? adapter : nullptr),
+ mStatus == WGPURequestAdapterStatus_Success ? ReturnToAPI(mAdapter) : nullptr,
mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling());
} else {
- mCallback2(mStatus,
- ToAPI(mStatus == WGPURequestAdapterStatus_Success ? adapter : nullptr),
- mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling(),
- mUserdata2.ExtractAsDangling());
+ mCallback2(
+ mStatus,
+ mStatus == WGPURequestAdapterStatus_Success ? ReturnToAPI(mAdapter) : nullptr,
+ mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling(),
+ mUserdata2.ExtractAsDangling());
}
}
@@ -127,7 +126,7 @@
// throughout the duration of a RequestAdapterEvent because the Event essentially takes
// ownership of it until either an error occurs at which point the Event cleans it up, or it
// returns the adapter to the user who then takes ownership as the Event goes away.
- raw_ptr<Adapter> mAdapter = nullptr;
+ Ref<Adapter> mAdapter;
};
WGPUWGSLFeatureName ToWGPUFeature(tint::wgsl::LanguageFeature f) {
@@ -207,7 +206,7 @@
WGPUFuture Instance::RequestAdapterF(const WGPURequestAdapterOptions* options,
const WGPURequestAdapterCallbackInfo& callbackInfo) {
Client* client = GetClient();
- Adapter* adapter = client->Make<Adapter>(GetEventManagerHandle());
+ Ref<Adapter> adapter = client->Make<Adapter>(GetEventManagerHandle());
auto [futureIDInternal, tracked] =
GetEventManager().TrackEvent(std::make_unique<RequestAdapterEvent>(callbackInfo, adapter));
if (!tracked) {
@@ -229,7 +228,7 @@
WGPUFuture Instance::RequestAdapter2(const WGPURequestAdapterOptions* options,
const WGPURequestAdapterCallbackInfo2& callbackInfo) {
Client* client = GetClient();
- Adapter* adapter = client->Make<Adapter>(GetEventManagerHandle());
+ Ref<Adapter> adapter = client->Make<Adapter>(GetEventManagerHandle());
auto [futureIDInternal, tracked] =
GetEventManager().TrackEvent(std::make_unique<RequestAdapterEvent>(callbackInfo, adapter));
if (!tracked) {
diff --git a/src/dawn/wire/client/ObjectBase.h b/src/dawn/wire/client/ObjectBase.h
index 58649d2..467e9a2 100644
--- a/src/dawn/wire/client/ObjectBase.h
+++ b/src/dawn/wire/client/ObjectBase.h
@@ -32,6 +32,7 @@
#include "partition_alloc/pointers/raw_ptr.h"
#include "dawn/common/LinkedList.h"
+#include "dawn/common/RefBase.h"
#include "dawn/wire/ObjectHandle.h"
#include "dawn/wire/ObjectType_autogen.h"
#include "dawn/wire/client/EventManager.h"
@@ -92,6 +93,29 @@
ObjectHandle mEventManagerHandle;
};
+// Ref<T> for a T that's an ObjectBase*
+namespace detail {
+
+template <typename T>
+struct ObjectBaseTraits {
+ static constexpr T* kNullValue = nullptr;
+ static void AddRef(T* value) { value->AddRef(); }
+ static void Release(T* value) { value->Release(); }
+};
+
+} // namespace detail
+
+template <typename T>
+class Ref : public RefBase<T*, detail::ObjectBaseTraits<T>> {
+ public:
+ using RefBase<T*, detail::ObjectBaseTraits<T>>::RefBase;
+};
+
+template <typename T>
+auto ReturnToAPI(Ref<T> r) {
+ return ToAPI(r.Detach());
+}
+
} // namespace dawn::wire::client
#endif // SRC_DAWN_WIRE_CLIENT_OBJECTBASE_H_
diff --git a/src/dawn/wire/client/ShaderModule.cpp b/src/dawn/wire/client/ShaderModule.cpp
index 31c4244..66951d8 100644
--- a/src/dawn/wire/client/ShaderModule.cpp
+++ b/src/dawn/wire/client/ShaderModule.cpp
@@ -28,6 +28,7 @@
#include "dawn/wire/client/ShaderModule.h"
#include <memory>
+#include <utility>
#include "dawn/wire/client/Client.h"
#include "partition_alloc/pointers/raw_ptr.h"
@@ -38,18 +39,16 @@
public:
static constexpr EventType kType = EventType::CompilationInfo;
- CompilationInfoEvent(const WGPUCompilationInfoCallbackInfo2& callbackInfo, ShaderModule* shader)
+ CompilationInfoEvent(const WGPUCompilationInfoCallbackInfo2& callbackInfo,
+ Ref<ShaderModule> shader)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
mUserdata1(callbackInfo.userdata1),
mUserdata2(callbackInfo.userdata2),
- mShader(shader) {
+ mShader(std::move(shader)) {
DAWN_ASSERT(mShader != nullptr);
- mShader->AddRef();
}
- ~CompilationInfoEvent() override { mShader.ExtractAsDangling()->Release(); }
-
EventType GetType() override { return kType; }
WireResult ReadyHook(FutureID futureId,
@@ -107,7 +106,7 @@
// Strong reference to the shader so that when we call the callback we can pass the
// compilation info from `mShader`.
- raw_ptr<ShaderModule> mShader;
+ Ref<ShaderModule> mShader;
};
ObjectType ShaderModule::GetObjectType() const {
diff --git a/src/dawn/wire/client/Surface.cpp b/src/dawn/wire/client/Surface.cpp
index 0979838..00d5f8d 100644
--- a/src/dawn/wire/client/Surface.cpp
+++ b/src/dawn/wire/client/Surface.cpp
@@ -77,14 +77,15 @@
dawn::ErrorLog() << "surface.GetCurrentTexture not supported yet with dawn_wire.";
Client* wireClient = GetClient();
- Texture* texture = wireClient->Make<Texture>(&mTextureDescriptor);
- surfaceTexture->texture = ToAPI(texture);
+ Ref<Texture> texture = wireClient->Make<Texture>(&mTextureDescriptor);
SurfaceGetCurrentTextureCmd cmd;
cmd.self = ToAPI(this);
cmd.selfId = GetWireId();
// cmd.result = texture->GetWireHandle(); // TODO(dawn:2320): Feed surfaceTexture to cmd
wireClient->SerializeCommand(cmd);
+
+ surfaceTexture->texture = ReturnToAPI(texture);
}
} // namespace dawn::wire::client
diff --git a/src/dawn/wire/client/SwapChain.cpp b/src/dawn/wire/client/SwapChain.cpp
index 8b18407..b0691d0 100644
--- a/src/dawn/wire/client/SwapChain.cpp
+++ b/src/dawn/wire/client/SwapChain.cpp
@@ -54,7 +54,7 @@
WGPUTexture SwapChain::GetCurrentTexture() {
Client* wireClient = GetClient();
- Texture* texture = wireClient->Make<Texture>(&mTextureDescriptor);
+ Ref<Texture> texture = wireClient->Make<Texture>(&mTextureDescriptor);
SwapChainGetCurrentTextureCmd cmd;
cmd.self = ToAPI(this);
@@ -62,7 +62,7 @@
cmd.result = texture->GetWireHandle();
wireClient->SerializeCommand(cmd);
- return ToAPI(texture);
+ return ReturnToAPI(texture);
}
} // namespace dawn::wire::client