[WGPUFuture] Implement GetCompilationInfo in wire/native with Futures.
Bug: dawn:1987
Change-Id: Ie2cc2d5177e359467fbcde7ea648ef5f2019cd87
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/175120
Reviewed-by: Shrek Shao <shrekshao@google.com>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/generator/templates/dawn/wire/client/ApiProcs.cpp b/generator/templates/dawn/wire/client/ApiProcs.cpp
index e07f44b..3830d73 100644
--- a/generator/templates/dawn/wire/client/ApiProcs.cpp
+++ b/generator/templates/dawn/wire/client/ApiProcs.cpp
@@ -32,6 +32,7 @@
#include <algorithm>
#include <cstring>
#include <string>
+#include <type_traits>
#include <vector>
namespace dawn::wire::client {
@@ -44,7 +45,11 @@
} else if constexpr (std::is_constructible_v<Child, const ObjectBaseParams&, const ObjectHandle&, decltype(args)...>) {
return p->GetClient()->template Make<Child>(p->GetEventManagerHandle(), args...);
} else {
- return p->GetClient()->template Make<Child>();
+ if constexpr (std::is_base_of_v<ObjectWithEventsBase, Child>) {
+ return p->GetClient()->template Make<Child>(p->GetEventManagerHandle());
+ } else {
+ return p->GetClient()->template Make<Child>();
+ }
}
}
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index d196202..2c7e22f 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -893,6 +893,15 @@
{"name": "userdata", "type": "void *"}
]
},
+ "compilation info callback info": {
+ "category": "structure",
+ "extensible": "in",
+ "members": [
+ {"name": "mode", "type": "callback mode"},
+ {"name": "callback", "type": "compilation info callback"},
+ {"name": "userdata", "type": "void *", "default": "nullptr"}
+ ]
+ },
"compilation info request status": {
"category": "enum",
"values": [
@@ -3289,6 +3298,15 @@
]
},
{
+ "name": "get compilation info f",
+ "_comment": "TODO(crbug.com/dawn/2021): This is dawn/emscripten-only until we rename it to replace the old API. See bug for details.",
+ "tags": ["dawn", "emscripten"],
+ "returns": "future",
+ "args": [
+ {"name": "callback info", "type": "compilation info callback info"}
+ ]
+ },
+ {
"name": "set label",
"returns": "void",
"args": [
diff --git a/src/dawn/dawn_wire.json b/src/dawn/dawn_wire.json
index 9365631..234bbc9 100644
--- a/src/dawn/dawn_wire.json
+++ b/src/dawn/dawn_wire.json
@@ -100,7 +100,8 @@
],
"shader module get compilation info": [
{ "name": "shader module id", "type": "ObjectId", "id_type": "shader module" },
- { "name": "request serial", "type": "uint64_t" }
+ { "name": "event manager handle", "type": "ObjectHandle" },
+ { "name": "future", "type": "future" }
],
"instance request adapter": [
{ "name": "instance id", "type": "ObjectId", "id_type": "instance" },
@@ -164,8 +165,8 @@
{ "name": "status", "type": "queue work done status" }
],
"shader module get compilation info callback": [
- { "name": "shader module", "type": "ObjectHandle", "handle_type": "shader module" },
- { "name": "request serial", "type": "uint64_t" },
+ { "name": "event manager", "type": "ObjectHandle" },
+ { "name": "future", "type": "future" },
{ "name": "status", "type": "compilation info request status" },
{ "name": "info", "type": "compilation info", "annotation": "const*", "optional": true }
],
@@ -232,6 +233,7 @@
"InstanceRequestAdapter",
"InstanceRequestAdapterF",
"ShaderModuleGetCompilationInfo",
+ "ShaderModuleGetCompilationInfoF",
"QuerySetGetType",
"QuerySetGetCount",
"QueueOnSubmittedWorkDone",
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index 2a9cc28..9d92a40 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -37,6 +37,7 @@
#include "dawn/native/ChainUtils.h"
#include "dawn/native/CompilationMessages.h"
#include "dawn/native/Device.h"
+#include "dawn/native/Instance.h"
#include "dawn/native/ObjectContentHasher.h"
#include "dawn/native/Pipeline.h"
#include "dawn/native/PipelineLayout.h"
@@ -1387,9 +1388,47 @@
if (callback == nullptr) {
return;
}
+ CompilationInfoCallbackInfo callbackInfo = {nullptr, wgpu::CallbackMode::AllowSpontaneous,
+ callback, userdata};
+ APIGetCompilationInfoF(callbackInfo);
+}
- callback(WGPUCompilationInfoRequestStatus_Success, mCompilationMessages->GetCompilationInfo(),
- userdata);
+Future ShaderModuleBase::APIGetCompilationInfoF(const CompilationInfoCallbackInfo& callbackInfo) {
+ struct CompilationInfoEvent final : public EventManager::TrackedEvent {
+ WGPUCompilationInfoCallback mCallback;
+ // TODO(https://crbug.com/dawn/2349): Investigate DanglingUntriaged in dawn/native.
+ raw_ptr<void, DanglingUntriaged> mUserdata;
+ // Need to keep a Ref of the compilation messages in case the ShaderModule goes away before
+ // the callback happens.
+ Ref<ShaderModuleBase> mShaderModule;
+
+ CompilationInfoEvent(const CompilationInfoCallbackInfo& callbackInfo,
+ Ref<ShaderModuleBase> shaderModule)
+ : TrackedEvent(callbackInfo.mode, TrackedEvent::Completed{}),
+ mCallback(callbackInfo.callback),
+ mUserdata(callbackInfo.userdata),
+ mShaderModule(std::move(shaderModule)) {
+ CompleteIfSpontaneous();
+ }
+
+ ~CompilationInfoEvent() override { EnsureComplete(EventCompletionType::Shutdown); }
+
+ void Complete(EventCompletionType completionType) override {
+ WGPUCompilationInfoRequestStatus status =
+ WGPUCompilationInfoRequestStatus_InstanceDropped;
+ const WGPUCompilationInfo* compilationInfo = nullptr;
+ if (completionType == EventCompletionType::Ready) {
+ status = WGPUCompilationInfoRequestStatus_Success;
+ compilationInfo = mShaderModule->mCompilationMessages->GetCompilationInfo();
+ }
+ if (mCallback) {
+ mCallback(status, compilationInfo, mUserdata);
+ }
+ }
+ };
+ FutureID futureID = GetDevice()->GetInstance()->GetEventManager()->TrackEvent(
+ AcquireRef(new CompilationInfoEvent(callbackInfo, this)));
+ return {futureID};
}
void ShaderModuleBase::InjectCompilationMessages(
diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h
index c31a16c..a186ac2 100644
--- a/src/dawn/native/ShaderModule.h
+++ b/src/dawn/native/ShaderModule.h
@@ -355,9 +355,9 @@
int GetTintProgramRecreateCountForTesting() const;
void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata);
+ Future APIGetCompilationInfoF(const CompilationInfoCallbackInfo& callbackInfo);
void InjectCompilationMessages(std::unique_ptr<OwnedCompilationMessages> compilationMessages);
-
OwnedCompilationMessages* GetCompilationMessages() const;
protected:
diff --git a/src/dawn/tests/unittests/wire/WireShaderModuleTests.cpp b/src/dawn/tests/unittests/wire/WireShaderModuleTests.cpp
index 9fc7ccc..495a9eb 100644
--- a/src/dawn/tests/unittests/wire/WireShaderModuleTests.cpp
+++ b/src/dawn/tests/unittests/wire/WireShaderModuleTests.cpp
@@ -27,6 +27,7 @@
#include <memory>
+#include "dawn/tests/unittests/wire/WireFutureTest.h"
#include "dawn/tests/unittests/wire/WireTest.h"
#include "dawn/wire/WireClient.h"
@@ -37,216 +38,195 @@
using testing::InvokeWithoutArgs;
using testing::Mock;
using testing::Return;
-using testing::StrictMock;
-// Mock class to add expectations on the wire calling callbacks
-class MockCompilationInfoCallback {
- public:
- MOCK_METHOD(void,
- Call,
- (WGPUCompilationInfoRequestStatus status,
- const WGPUCompilationInfo* info,
- void* userdata));
-};
-
-std::unique_ptr<StrictMock<MockCompilationInfoCallback>> mockCompilationInfoCallback;
-void ToMockGetCompilationInfoCallback(WGPUCompilationInfoRequestStatus status,
- const WGPUCompilationInfo* info,
- void* userdata) {
- mockCompilationInfoCallback->Call(status, info, userdata);
-}
-
-class WireShaderModuleTests : public WireTest {
- public:
- WireShaderModuleTests() {}
- ~WireShaderModuleTests() override = default;
+using WireShaderModuleTestBase = WireFutureTest<WGPUCompilationInfoCallback,
+ WGPUCompilationInfoCallbackInfo,
+ wgpuShaderModuleGetCompilationInfo,
+ wgpuShaderModuleGetCompilationInfoF>;
+class WireShaderModuleTests : public WireShaderModuleTestBase {
+ protected:
+ // Overriden version of wgpuShaderModuleGetCompilationInfo that defers to the API call based on
+ // the test callback mode.
+ void ShaderModuleGetCompilationInfo(WGPUShaderModule s, void* userdata = nullptr) {
+ CallImpl(userdata, s);
+ }
void SetUp() override {
- WireTest::SetUp();
-
- mockCompilationInfoCallback = std::make_unique<StrictMock<MockCompilationInfoCallback>>();
- apiShaderModule = api.GetNewShaderModule();
-
+ WireShaderModuleTestBase::SetUp();
WGPUShaderModuleDescriptor descriptor = {};
+ apiShaderModule = api.GetNewShaderModule();
shaderModule = wgpuDeviceCreateShaderModule(device, &descriptor);
-
EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
.WillOnce(Return(apiShaderModule))
.RetiresOnSaturation();
FlushClient();
}
- void TearDown() override {
- WireTest::TearDown();
-
- // Delete mock so that expectations are checked
- mockCompilationInfoCallback = nullptr;
- }
-
- void FlushClient() {
- WireTest::FlushClient();
- Mock::VerifyAndClearExpectations(&mockCompilationInfoCallback);
- }
-
- void FlushServer() {
- WireTest::FlushServer();
- Mock::VerifyAndClearExpectations(&mockCompilationInfoCallback);
- }
-
- protected:
WGPUShaderModule shaderModule;
WGPUShaderModule apiShaderModule;
+
+ // Default responses.
+ WGPUCompilationMessage mMessage = {
+ nullptr, "Test Message", WGPUCompilationMessageType_Info, 2, 4, 6, 8, 4, 6, 8};
+ WGPUCompilationInfo mCompilationInfo = {nullptr, 1, &mMessage};
};
-// Check getting CompilationInfo for a successfully created shader module
-TEST_F(WireShaderModuleTests, GetCompilationInfo) {
- wgpuShaderModuleGetCompilationInfo(shaderModule, ToMockGetCompilationInfoCallback, nullptr);
+DAWN_INSTANTIATE_WIRE_FUTURE_TEST_P(WireShaderModuleTests);
- WGPUCompilationMessage message = {
- nullptr, "Test Message", WGPUCompilationMessageType_Info, 2, 4, 6, 8, 4, 6, 8};
- WGPUCompilationInfo compilationInfo;
- compilationInfo.nextInChain = nullptr;
- compilationInfo.messageCount = 1;
- compilationInfo.messages = &message;
+// Check getting CompilationInfo for a successfully created shader module
+TEST_P(WireShaderModuleTests, GetCompilationInfo) {
+ ShaderModuleGetCompilationInfo(shaderModule);
EXPECT_CALL(api, OnShaderModuleGetCompilationInfo(apiShaderModule, _, _))
.WillOnce(InvokeWithoutArgs([&] {
api.CallShaderModuleGetCompilationInfoCallback(
- apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &compilationInfo);
+ apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &mCompilationInfo);
}));
-
FlushClient();
+ FlushFutures();
- EXPECT_CALL(*mockCompilationInfoCallback,
- Call(WGPUCompilationInfoRequestStatus_Success,
- MatchesLambda([&](const WGPUCompilationInfo* info) -> bool {
- if (info->messageCount != compilationInfo.messageCount) {
- return false;
- }
- const WGPUCompilationMessage* infoMessage = &info->messages[0];
- return strcmp(infoMessage->message, message.message) == 0 &&
- infoMessage->nextInChain == message.nextInChain &&
- infoMessage->type == message.type &&
- infoMessage->lineNum == message.lineNum &&
- infoMessage->linePos == message.linePos &&
- infoMessage->offset == message.offset &&
- infoMessage->length == message.length;
- }),
- _))
- .Times(1);
- FlushServer();
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb, Call(WGPUCompilationInfoRequestStatus_Success,
+ MatchesLambda([&](const WGPUCompilationInfo* info) -> bool {
+ if (info->messageCount != mCompilationInfo.messageCount) {
+ return false;
+ }
+ const WGPUCompilationMessage* infoMessage = &info->messages[0];
+ return strcmp(infoMessage->message, mMessage.message) == 0 &&
+ infoMessage->nextInChain == mMessage.nextInChain &&
+ infoMessage->type == mMessage.type &&
+ infoMessage->lineNum == mMessage.lineNum &&
+ infoMessage->linePos == mMessage.linePos &&
+ infoMessage->offset == mMessage.offset &&
+ infoMessage->length == mMessage.length;
+ }),
+ nullptr))
+ .Times(1);
+
+ FlushCallbacks();
+ });
}
// Test that calling GetCompilationInfo then disconnecting the wire calls the callback with a
// device loss.
-TEST_F(WireShaderModuleTests, GetCompilationInfoBeforeDisconnect) {
- wgpuShaderModuleGetCompilationInfo(shaderModule, ToMockGetCompilationInfoCallback, nullptr);
-
- WGPUCompilationMessage message = {
- nullptr, "Test Message", WGPUCompilationMessageType_Info, 2, 4, 6, 8, 4, 6, 8};
- WGPUCompilationInfo compilationInfo;
- compilationInfo.nextInChain = nullptr;
- compilationInfo.messageCount = 1;
- compilationInfo.messages = &message;
+TEST_P(WireShaderModuleTests, GetCompilationInfoBeforeDisconnect) {
+ ShaderModuleGetCompilationInfo(shaderModule);
EXPECT_CALL(api, OnShaderModuleGetCompilationInfo(apiShaderModule, _, _))
.WillOnce(InvokeWithoutArgs([&] {
api.CallShaderModuleGetCompilationInfoCallback(
- apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &compilationInfo);
+ apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &mCompilationInfo);
}));
FlushClient();
+ FlushFutures();
- EXPECT_CALL(*mockCompilationInfoCallback,
- Call(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr, _));
- GetWireClient()->Disconnect();
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb,
+ Call(WGPUCompilationInfoRequestStatus_InstanceDropped, nullptr, nullptr))
+ .Times(1);
+
+ GetWireClient()->Disconnect();
+ });
}
// Test that calling GetCompilationInfo after disconnecting the wire calls the callback with a
// device loss.
-TEST_F(WireShaderModuleTests, GetCompilationInfoAfterDisconnect) {
+TEST_P(WireShaderModuleTests, GetCompilationInfoAfterDisconnect) {
GetWireClient()->Disconnect();
- EXPECT_CALL(*mockCompilationInfoCallback,
- Call(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr, _));
- wgpuShaderModuleGetCompilationInfo(shaderModule, ToMockGetCompilationInfoCallback, nullptr);
-}
-// Hack to pass in test context into user callback
-struct TestData {
- WireShaderModuleTests* pTest;
- WGPUShaderModule* pTestShaderModule;
- size_t numRequests;
-};
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb,
+ Call(WGPUCompilationInfoRequestStatus_InstanceDropped, nullptr, nullptr))
+ .Times(1);
-static void ToMockBufferMapCallbackWithNewRequests(WGPUCompilationInfoRequestStatus status,
- const WGPUCompilationInfo* info,
- void* userdata) {
- TestData* testData = reinterpret_cast<TestData*>(userdata);
- // Mimic the user callback is sending new requests
- ASSERT_NE(testData, nullptr);
- ASSERT_NE(testData->pTest, nullptr);
- ASSERT_NE(testData->pTestShaderModule, nullptr);
-
- mockCompilationInfoCallback->Call(status, info, testData->pTest);
-
- // Send the requests a number of times
- for (size_t i = 0; i < testData->numRequests; i++) {
- wgpuShaderModuleGetCompilationInfo(*(testData->pTestShaderModule),
- ToMockGetCompilationInfoCallback, nullptr);
- }
+ ShaderModuleGetCompilationInfo(shaderModule);
+ });
}
// Test that requests inside user callbacks before disconnect are called
-TEST_F(WireShaderModuleTests, GetCompilationInfoInsideCallbackBeforeDisconnect) {
- TestData testData = {this, &shaderModule, 10};
+TEST_P(WireShaderModuleTests, GetCompilationInfoInsideCallbackBeforeDisconnect) {
+ static constexpr size_t kNumRequests = 10;
- wgpuShaderModuleGetCompilationInfo(shaderModule, ToMockBufferMapCallbackWithNewRequests,
- &testData);
-
- WGPUCompilationMessage message = {
- nullptr, "Test Message", WGPUCompilationMessageType_Info, 2, 4, 6, 8, 4, 6, 8};
- WGPUCompilationInfo compilationInfo;
- compilationInfo.nextInChain = nullptr;
- compilationInfo.messageCount = 1;
- compilationInfo.messages = &message;
+ ShaderModuleGetCompilationInfo(shaderModule);
EXPECT_CALL(api, OnShaderModuleGetCompilationInfo(apiShaderModule, _, _))
.WillOnce(InvokeWithoutArgs([&] {
api.CallShaderModuleGetCompilationInfoCallback(
- apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &compilationInfo);
+ apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &mCompilationInfo);
}));
FlushClient();
+ FlushFutures();
- EXPECT_CALL(*mockCompilationInfoCallback,
- Call(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr, _))
- .Times(1 + testData.numRequests);
- GetWireClient()->Disconnect();
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb,
+ Call(WGPUCompilationInfoRequestStatus_InstanceDropped, nullptr, nullptr))
+ .Times(kNumRequests + 1)
+ .WillOnce([&]() {
+ for (size_t i = 0; i < kNumRequests; i++) {
+ ShaderModuleGetCompilationInfo(shaderModule);
+ }
+ })
+ .WillRepeatedly(Return());
+
+ GetWireClient()->Disconnect();
+ });
}
// Test that requests inside user callbacks before object destruction are called
-TEST_F(WireShaderModuleTests, GetCompilationInfoInsideCallbackBeforeDestruction) {
- TestData testData = {this, &shaderModule, 10};
+TEST_P(WireShaderModuleTests, GetCompilationInfoInsideCallbackBeforeDestruction) {
+ static constexpr size_t kNumRequests = 10;
- wgpuShaderModuleGetCompilationInfo(shaderModule, ToMockBufferMapCallbackWithNewRequests,
- &testData);
-
- WGPUCompilationMessage message = {
- nullptr, "Test Message", WGPUCompilationMessageType_Info, 2, 4, 6, 8, 4, 6, 8};
- WGPUCompilationInfo compilationInfo;
- compilationInfo.nextInChain = nullptr;
- compilationInfo.messageCount = 1;
- compilationInfo.messages = &message;
+ ShaderModuleGetCompilationInfo(shaderModule);
EXPECT_CALL(api, OnShaderModuleGetCompilationInfo(apiShaderModule, _, _))
.WillOnce(InvokeWithoutArgs([&] {
api.CallShaderModuleGetCompilationInfoCallback(
- apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &compilationInfo);
+ apiShaderModule, WGPUCompilationInfoRequestStatus_Success, &mCompilationInfo);
}));
FlushClient();
+ FlushFutures();
- EXPECT_CALL(*mockCompilationInfoCallback,
- Call(WGPUCompilationInfoRequestStatus_Unknown, nullptr, _))
- .Times(1 + testData.numRequests);
- wgpuShaderModuleRelease(shaderModule);
+ if (IsSpontaneous()) {
+ // In spontaneous mode, the callbacks can be fired immediately so they all happen when we
+ // flush the first callback.
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb, Call(WGPUCompilationInfoRequestStatus_Success, _, nullptr))
+ .Times(kNumRequests + 1)
+ .WillOnce([&]() {
+ for (size_t i = 0; i < kNumRequests; i++) {
+ ShaderModuleGetCompilationInfo(shaderModule);
+ }
+ })
+ .WillRepeatedly(Return());
+
+ wgpuShaderModuleRelease(shaderModule);
+ FlushCallbacks();
+ });
+ } else {
+ // In non-spontaneous mode, we need to flush the client and callbacks again before the
+ // second round of callbacks are fired.
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb, Call(WGPUCompilationInfoRequestStatus_Success, _, nullptr))
+ .WillOnce([&]() {
+ for (size_t i = 0; i < kNumRequests; i++) {
+ ShaderModuleGetCompilationInfo(shaderModule);
+ }
+ });
+
+ FlushCallbacks();
+ });
+
+ wgpuShaderModuleRelease(shaderModule);
+ FlushClient();
+ FlushFutures();
+ ExpectWireCallbacksWhen([&](auto& mockCb) {
+ EXPECT_CALL(mockCb, Call(WGPUCompilationInfoRequestStatus_Success, _, nullptr))
+ .Times(kNumRequests)
+ .WillRepeatedly(Return());
+
+ FlushCallbacks();
+ });
+ }
}
} // anonymous namespace
diff --git a/src/dawn/wire/client/ClientDoers.cpp b/src/dawn/wire/client/ClientDoers.cpp
index ba4064d..5dfdb7d 100644
--- a/src/dawn/wire/client/ClientDoers.cpp
+++ b/src/dawn/wire/client/ClientDoers.cpp
@@ -77,18 +77,4 @@
return WireResult::Success;
}
-WireResult Client::DoShaderModuleGetCompilationInfoCallback(ShaderModule* shaderModule,
- uint64_t requestSerial,
- WGPUCompilationInfoRequestStatus status,
- const WGPUCompilationInfo* info) {
- // The shader module might have been deleted or recreated so this isn't an error.
- if (shaderModule == nullptr) {
- return WireResult::Success;
- }
- if (shaderModule->GetCompilationInfoCallback(requestSerial, status, info)) {
- return WireResult::Success;
- }
- return WireResult::FatalError;
-}
-
} // namespace dawn::wire::client
diff --git a/src/dawn/wire/client/EventManager.cpp b/src/dawn/wire/client/EventManager.cpp
index 130a64e..f1f0f15 100644
--- a/src/dawn/wire/client/EventManager.cpp
+++ b/src/dawn/wire/client/EventManager.cpp
@@ -116,6 +116,8 @@
if (event->GetCallbackMode() != WGPUCallbackMode_AllowSpontaneous) {
events.emplace(it->first, std::move(event));
it = trackedEvents->erase(it);
+ } else {
+ ++it;
}
}
});
diff --git a/src/dawn/wire/client/EventManager.h b/src/dawn/wire/client/EventManager.h
index b3123ca..de30fb3 100644
--- a/src/dawn/wire/client/EventManager.h
+++ b/src/dawn/wire/client/EventManager.h
@@ -46,6 +46,7 @@
class Client;
enum class EventType {
+ CompilationInfo,
CreateComputePipeline,
CreateRenderPipeline,
MapAsync,
diff --git a/src/dawn/wire/client/ShaderModule.cpp b/src/dawn/wire/client/ShaderModule.cpp
index 9cc2196..fc65c08 100644
--- a/src/dawn/wire/client/ShaderModule.cpp
+++ b/src/dawn/wire/client/ShaderModule.cpp
@@ -27,56 +27,127 @@
#include "dawn/wire/client/ShaderModule.h"
+#include <memory>
+
#include "dawn/wire/client/Client.h"
+#include "partition_alloc/pointers/raw_ptr.h"
namespace dawn::wire::client {
-ShaderModule::~ShaderModule() {
- ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown);
-}
+class ShaderModule::CompilationInfoEvent final : public TrackedEvent {
+ public:
+ static constexpr EventType kType = EventType::CompilationInfo;
+
+ CompilationInfoEvent(const WGPUCompilationInfoCallbackInfo& callbackInfo, ShaderModule* shader)
+ : TrackedEvent(callbackInfo.mode),
+ mCallback(callbackInfo.callback),
+ mUserdata(callbackInfo.userdata),
+ mShader(shader) {
+ DAWN_ASSERT(mShader != nullptr);
+ mShader->Reference();
+ }
+
+ ~CompilationInfoEvent() override { mShader->Release(); }
+
+ EventType GetType() override { return kType; }
+
+ WireResult ReadyHook(FutureID futureId,
+ WGPUCompilationInfoRequestStatus status,
+ const WGPUCompilationInfo* info) {
+ if (mShader->mCompilationInfo) {
+ // If we already cached the compilation info on the shader, we don't need to do it
+ // again. This can happen if we were to call GetCompilationInfo multiple times before
+ // the wire flushes.
+ return ReadyHook(futureId);
+ }
+
+ mStatus = status;
+ mShader->mMessageStrings.reserve(info->messageCount);
+ mShader->mMessages.reserve(info->messageCount);
+ for (size_t i = 0; i < info->messageCount; i++) {
+ mShader->mMessageStrings.push_back(info->messages[i].message);
+ mShader->mMessages.push_back(info->messages[i]);
+ mShader->mMessages[i].message = mShader->mMessageStrings[i].c_str();
+ }
+ mShader->mCompilationInfo = {nullptr, mShader->mMessages.size(), mShader->mMessages.data()};
+
+ mCompilationInfo = &*mShader->mCompilationInfo;
+ return WireResult::Success;
+ }
+
+ WireResult ReadyHook(FutureID futureId) {
+ // We call this ReadyHook when we already have a cached compilation on the shader (usually
+ // from a previous GetCompilationInfo call).
+ DAWN_ASSERT(mShader->mCompilationInfo);
+ mStatus = WGPUCompilationInfoRequestStatus_Success;
+ mCompilationInfo = &(*mShader->mCompilationInfo);
+ return WireResult::Success;
+ }
+
+ private:
+ void CompleteImpl(FutureID futureID, EventCompletionType completionType) override {
+ if (completionType == EventCompletionType::Shutdown) {
+ mStatus = WGPUCompilationInfoRequestStatus_InstanceDropped;
+ mCompilationInfo = nullptr;
+ }
+ if (mCallback) {
+ mCallback(mStatus, mCompilationInfo, mUserdata);
+ }
+ }
+
+ WGPUCompilationInfoCallback mCallback;
+ // TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire.
+ raw_ptr<void, DanglingUntriaged> mUserdata;
+
+ WGPUCompilationInfoRequestStatus mStatus;
+ const WGPUCompilationInfo* mCompilationInfo = nullptr;
+
+ // Strong reference to the buffer so that when we call the callback we can pass the buffer.
+ // TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire.
+ const raw_ptr<ShaderModule, DanglingUntriaged> mShader;
+};
ObjectType ShaderModule::GetObjectType() const {
return ObjectType::ShaderModule;
}
void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
- Client* client = GetClient();
- if (client->IsDisconnected()) {
- callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr, userdata);
- return;
+ WGPUCompilationInfoCallbackInfo callbackInfo = {};
+ callbackInfo.mode = WGPUCallbackMode_AllowSpontaneous;
+ callbackInfo.callback = callback;
+ callbackInfo.userdata = userdata;
+ GetCompilationInfoF(callbackInfo);
+}
+
+WGPUFuture ShaderModule::GetCompilationInfoF(const WGPUCompilationInfoCallbackInfo& callbackInfo) {
+ auto [futureIDInternal, tracked] =
+ GetEventManager().TrackEvent(std::make_unique<CompilationInfoEvent>(callbackInfo, this));
+ if (!tracked) {
+ return {futureIDInternal};
}
- uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
+ // If we already have a cached compilation info object, we can set it ready now.
+ if (mCompilationInfo) {
+ DAWN_CHECK(GetEventManager().SetFutureReady<CompilationInfoEvent>(futureIDInternal) ==
+ WireResult::Success);
+ return {futureIDInternal};
+ }
ShaderModuleGetCompilationInfoCmd cmd;
cmd.shaderModuleId = GetWireId();
- cmd.requestSerial = serial;
+ cmd.eventManagerHandle = GetEventManagerHandle();
+ cmd.future = {futureIDInternal};
- client->SerializeCommand(cmd);
+ GetClient()->SerializeCommand(cmd);
+ return {futureIDInternal};
}
-bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
- WGPUCompilationInfoRequestStatus status,
- const WGPUCompilationInfo* info) {
- CompilationInfoRequest request;
- if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) {
- return false;
- }
-
- request.callback(status, info, request.userdata);
- return true;
-}
-
-void ShaderModule::CancelCallbacksForDisconnect() {
- ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost);
-}
-
-void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) {
- mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) {
- if (request->callback != nullptr) {
- request->callback(status, nullptr, request->userdata);
- }
- });
+WireResult Client::DoShaderModuleGetCompilationInfoCallback(ObjectHandle eventManager,
+ WGPUFuture future,
+ WGPUCompilationInfoRequestStatus status,
+ const WGPUCompilationInfo* info) {
+ return GetEventManager(eventManager)
+ .SetFutureReady<ShaderModule::CompilationInfoEvent>(future.id, status, info);
}
} // namespace dawn::wire::client
diff --git a/src/dawn/wire/client/ShaderModule.h b/src/dawn/wire/client/ShaderModule.h
index 8708ef9..9b6bb26 100644
--- a/src/dawn/wire/client/ShaderModule.h
+++ b/src/dawn/wire/client/ShaderModule.h
@@ -28,36 +28,32 @@
#ifndef SRC_DAWN_WIRE_CLIENT_SHADERMODULE_H_
#define SRC_DAWN_WIRE_CLIENT_SHADERMODULE_H_
+#include <optional>
+#include <string>
+#include <vector>
+
#include "dawn/webgpu.h"
-#include "partition_alloc/pointers/raw_ptr.h"
#include "dawn/wire/client/ObjectBase.h"
-#include "dawn/wire/client/RequestTracker.h"
namespace dawn::wire::client {
-class ShaderModule final : public ObjectBase {
+class ShaderModule final : public ObjectWithEventsBase {
public:
- using ObjectBase::ObjectBase;
- ~ShaderModule() override;
+ using ObjectWithEventsBase::ObjectWithEventsBase;
ObjectType GetObjectType() const override;
void GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata);
- bool GetCompilationInfoCallback(uint64_t requestSerial,
- WGPUCompilationInfoRequestStatus status,
- const WGPUCompilationInfo* info);
+ WGPUFuture GetCompilationInfoF(const WGPUCompilationInfoCallbackInfo& callbackInfo);
private:
- void CancelCallbacksForDisconnect() override;
- void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status);
+ friend class Client;
+ class CompilationInfoEvent;
- struct CompilationInfoRequest {
- WGPUCompilationInfoCallback callback = nullptr;
- // TODO(https://crbug.com/dawn/2345): Investigate `DanglingUntriaged` in dawn/wire.
- raw_ptr<void, DanglingUntriaged> userdata = nullptr;
- };
- RequestTracker<CompilationInfoRequest> mCompilationInfoRequests;
+ std::optional<WGPUCompilationInfo> mCompilationInfo;
+ std::vector<std::string> mMessageStrings;
+ std::vector<WGPUCompilationMessage> mMessages;
};
} // namespace dawn::wire::client
diff --git a/src/dawn/wire/server/Server.h b/src/dawn/wire/server/Server.h
index e3ee61c..44795df 100644
--- a/src/dawn/wire/server/Server.h
+++ b/src/dawn/wire/server/Server.h
@@ -125,8 +125,8 @@
struct ShaderModuleGetCompilationInfoUserdata : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
- ObjectHandle shaderModule;
- uint64_t requestSerial;
+ ObjectHandle eventManager;
+ WGPUFuture future;
};
struct QueueWorkDoneUserdata : CallbackUserdata {
diff --git a/src/dawn/wire/server/ServerShaderModule.cpp b/src/dawn/wire/server/ServerShaderModule.cpp
index 971f734..21c28aa 100644
--- a/src/dawn/wire/server/ServerShaderModule.cpp
+++ b/src/dawn/wire/server/ServerShaderModule.cpp
@@ -32,10 +32,11 @@
namespace dawn::wire::server {
WireResult Server::DoShaderModuleGetCompilationInfo(Known<WGPUShaderModule> shaderModule,
- uint64_t requestSerial) {
+ ObjectHandle eventManager,
+ WGPUFuture future) {
auto userdata = MakeUserdata<ShaderModuleGetCompilationInfoUserdata>();
- userdata->shaderModule = shaderModule.AsHandle();
- userdata->requestSerial = requestSerial;
+ userdata->eventManager = eventManager;
+ userdata->future = future;
mProcs.shaderModuleGetCompilationInfo(
shaderModule->handle, ForwardToServer<&Server::OnShaderModuleGetCompilationInfo>,
@@ -47,8 +48,8 @@
WGPUCompilationInfoRequestStatus status,
const WGPUCompilationInfo* info) {
ReturnShaderModuleGetCompilationInfoCallbackCmd cmd;
- cmd.shaderModule = data->shaderModule;
- cmd.requestSerial = data->requestSerial;
+ cmd.eventManager = data->eventManager;
+ cmd.future = data->future;
cmd.status = status;
cmd.info = info;