Add a per-thread proc table using thread local storage
In situations where both dawn_wire and dawn_native are used on separate
threads (Chrome with --single-process or --in-process-gpu), it's
desirable to have a per-thread proc table so that the WebGPU C++ API can
still be used. This eliminates classes of bugs with manual
reference/release errors.
This also changes many of the GetProcs functions to return const
references to the static proc tables known at compile time, instead of a
copy.
Bug: none
Change-Id: I8775bb715b312dd9476a1903fbd797d4b1302614
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/29240
Reviewed-by: Stephen White <senorblanco@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py
index b7d9003..02c5723 100644
--- a/generator/dawn_json_generator.py
+++ b/generator/dawn_json_generator.py
@@ -687,6 +687,10 @@
renders.append(
FileRender('dawn_proc.c', 'src/dawn/dawn_proc.c',
[base_params, api_params]))
+ renders.append(
+ FileRender('dawn_thread_dispatch_proc.cpp',
+ 'src/dawn/dawn_thread_dispatch_proc.cpp',
+ [base_params, api_params]))
if 'dawncpp' in targets:
renders.append(
diff --git a/generator/templates/dawn_native/ProcTable.cpp b/generator/templates/dawn_native/ProcTable.cpp
index 88b780c..defae2d 100644
--- a/generator/templates/dawn_native/ProcTable.cpp
+++ b/generator/templates/dawn_native/ProcTable.cpp
@@ -135,16 +135,17 @@
return result;
}
- DawnProcTable GetProcsAutogen() {
- DawnProcTable table;
- table.getProcAddress = NativeGetProcAddress;
- table.createInstance = NativeCreateInstance;
+ static DawnProcTable gProcTable = {
+ NativeGetProcAddress,
+ NativeCreateInstance,
{% for type in by_category["object"] %}
{% for method in c_methods(type) %}
- table.{{as_varName(type.name, method.name)}} = Native{{as_MethodSuffix(type.name, method.name)}};
+ Native{{as_MethodSuffix(type.name, method.name)}},
{% endfor %}
{% endfor %}
- return table;
- }
+ };
+ const DawnProcTable& GetProcsAutogen() {
+ return gProcTable;
+ }
}
diff --git a/generator/templates/dawn_proc_table.h b/generator/templates/dawn_proc_table.h
index 197f300..1da1f73 100644
--- a/generator/templates/dawn_proc_table.h
+++ b/generator/templates/dawn_proc_table.h
@@ -17,6 +17,7 @@
#include "dawn/webgpu.h"
+// Note: Often allocated as a static global. Do not add a complex constructor.
typedef struct DawnProcTable {
WGPUProcGetProcAddress getProcAddress;
WGPUProcCreateInstance createInstance;
diff --git a/generator/templates/dawn_thread_dispatch_proc.cpp b/generator/templates/dawn_thread_dispatch_proc.cpp
new file mode 100644
index 0000000..bfc7794
--- /dev/null
+++ b/generator/templates/dawn_thread_dispatch_proc.cpp
@@ -0,0 +1,52 @@
+#include "dawn/dawn_thread_dispatch_proc.h"
+
+#include <thread>
+
+static DawnProcTable nullProcs;
+thread_local DawnProcTable perThreadProcs;
+
+void dawnProcSetPerThreadProcs(const DawnProcTable* procs) {
+ if (procs) {
+ perThreadProcs = *procs;
+ } else {
+ perThreadProcs = nullProcs;
+ }
+}
+
+static WGPUProc ThreadDispatchGetProcAddress(WGPUDevice device, const char* procName) {
+ return perThreadProcs.getProcAddress(device, procName);
+}
+
+static WGPUInstance ThreadDispatchCreateInstance(WGPUInstanceDescriptor const * descriptor) {
+ return perThreadProcs.createInstance(descriptor);
+}
+
+{% for type in by_category["object"] %}
+ {% for method in c_methods(type) %}
+ static {{as_cType(method.return_type.name)}} ThreadDispatch{{as_MethodSuffix(type.name, method.name)}}(
+ {{-as_cType(type.name)}} {{as_varName(type.name)}}
+ {%- for arg in method.arguments -%}
+ , {{as_annotated_cType(arg)}}
+ {%- endfor -%}
+ ) {
+ {% if method.return_type.name.canonical_case() != "void" %}return {% endif %}
+ perThreadProcs.{{as_varName(type.name, method.name)}}({{as_varName(type.name)}}
+ {%- for arg in method.arguments -%}
+ , {{as_varName(arg.name)}}
+ {%- endfor -%}
+ );
+ }
+ {% endfor %}
+{% endfor %}
+
+extern "C" {
+ DawnProcTable dawnThreadDispatchProcTable = {
+ ThreadDispatchGetProcAddress,
+ ThreadDispatchCreateInstance,
+{% for type in by_category["object"] %}
+ {% for method in c_methods(type) %}
+ ThreadDispatch{{as_MethodSuffix(type.name, method.name)}},
+ {% endfor %}
+{% endfor %}
+ };
+}
diff --git a/generator/templates/dawn_wire/client/ApiProcs.cpp b/generator/templates/dawn_wire/client/ApiProcs.cpp
index 3edba1a..0ad9a77 100644
--- a/generator/templates/dawn_wire/client/ApiProcs.cpp
+++ b/generator/templates/dawn_wire/client/ApiProcs.cpp
@@ -293,21 +293,16 @@
return result;
}
- //* Some commands don't have a custom wire format, but need to be handled manually to update
- //* some client-side state tracking. For these we have two functions:
- //* - An autogenerated Client{{suffix}} method that sends the command on the wire
- //* - A manual ProxyClient{{suffix}} method that will be inserted in the proctable instead of
- //* the autogenerated one, and that will have to call Client{{suffix}}
- DawnProcTable GetProcs() {
- DawnProcTable table;
- table.getProcAddress = ClientGetProcAddress;
- table.createInstance = ClientCreateInstance;
+ static DawnProcTable gProcTable = {
+ ClientGetProcAddress,
+ ClientCreateInstance,
{% for type in by_category["object"] %}
{% for method in c_methods(type) %}
- {% set suffix = as_MethodSuffix(type.name, method.name) %}
- table.{{as_varName(type.name, method.name)}} = Client{{suffix}};
+ Client{{as_MethodSuffix(type.name, method.name)}},
{% endfor %}
{% endfor %}
- return table;
+ };
+ const DawnProcTable& GetProcs() {
+ return gProcTable;
}
}} // namespace dawn_wire::client
diff --git a/src/dawn/BUILD.gn b/src/dawn/BUILD.gn
index 9034be4..ad48712 100644
--- a/src/dawn/BUILD.gn
+++ b/src/dawn/BUILD.gn
@@ -87,7 +87,10 @@
dawn_json_generator("dawn_proc_gen") {
target = "dawn_proc"
- outputs = [ "src/dawn/dawn_proc.c" ]
+ outputs = [
+ "src/dawn/dawn_proc.c",
+ "src/dawn/dawn_thread_dispatch_proc.cpp",
+ ]
}
dawn_component("dawn_proc") {
@@ -96,5 +99,8 @@
public_deps = [ ":dawn_headers" ]
deps = [ ":dawn_proc_gen" ]
sources = get_target_outputs(":dawn_proc_gen")
- sources += [ "${dawn_root}/src/include/dawn/dawn_proc.h" ]
+ sources += [
+ "${dawn_root}/src/include/dawn/dawn_proc.h",
+ "${dawn_root}/src/include/dawn/dawn_thread_dispatch_proc.h",
+ ]
}
diff --git a/src/dawn_native/DawnNative.cpp b/src/dawn_native/DawnNative.cpp
index bfc47db..52efa18 100644
--- a/src/dawn_native/DawnNative.cpp
+++ b/src/dawn_native/DawnNative.cpp
@@ -22,9 +22,9 @@
namespace dawn_native {
- DawnProcTable GetProcsAutogen();
+ const DawnProcTable& GetProcsAutogen();
- DawnProcTable GetProcs() {
+ const DawnProcTable& GetProcs() {
return GetProcsAutogen();
}
diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp
index e6fe263..430a55c 100644
--- a/src/dawn_wire/WireClient.cpp
+++ b/src/dawn_wire/WireClient.cpp
@@ -26,7 +26,7 @@
}
// static
- DawnProcTable WireClient::GetProcs() {
+ const DawnProcTable& WireClient::GetProcs() {
return client::GetProcs();
}
diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h
index d8df86d..be70e75 100644
--- a/src/dawn_wire/client/Client.h
+++ b/src/dawn_wire/client/Client.h
@@ -68,7 +68,7 @@
bool mIsDisconnected = false;
};
- DawnProcTable GetProcs();
+ const DawnProcTable& GetProcs();
std::unique_ptr<MemoryTransferService> CreateInlineMemoryTransferService();
diff --git a/src/include/dawn/dawn_thread_dispatch_proc.h b/src/include/dawn/dawn_thread_dispatch_proc.h
new file mode 100644
index 0000000..4d08ba8
--- /dev/null
+++ b/src/include/dawn/dawn_thread_dispatch_proc.h
@@ -0,0 +1,33 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWN_DAWN_THREAD_DISPATCH_PROC_H_
+#define DAWN_DAWN_THREAD_DISPATCH_PROC_H_
+
+#include "dawn/dawn_proc.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Call dawnProcSetProcs(&dawnThreadDispatchProcTable) and then use dawnProcSetPerThreadProcs
+// to set per-thread procs.
+WGPU_EXPORT extern DawnProcTable dawnThreadDispatchProcTable;
+WGPU_EXPORT void dawnProcSetPerThreadProcs(const DawnProcTable* procs);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // DAWN_DAWN_THREAD_DISPATCH_PROC_H_
diff --git a/src/include/dawn_native/DawnNative.h b/src/include/dawn_native/DawnNative.h
index 2199efa..6498646 100644
--- a/src/include/dawn_native/DawnNative.h
+++ b/src/include/dawn_native/DawnNative.h
@@ -170,7 +170,7 @@
};
// Backend-agnostic API for dawn_native
- DAWN_NATIVE_EXPORT DawnProcTable GetProcs();
+ DAWN_NATIVE_EXPORT const DawnProcTable& GetProcs();
// Query the names of all the toggles that are enabled in device
DAWN_NATIVE_EXPORT std::vector<const char*> GetTogglesUsed(WGPUDevice device);
diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h
index 815b66b..50da913 100644
--- a/src/include/dawn_wire/WireClient.h
+++ b/src/include/dawn_wire/WireClient.h
@@ -26,6 +26,8 @@
namespace client {
class Client;
class MemoryTransferService;
+
+ DAWN_WIRE_EXPORT const DawnProcTable& GetProcs();
} // namespace client
struct ReservedTexture {
@@ -44,7 +46,8 @@
WireClient(const WireClientDescriptor& descriptor);
~WireClient() override;
- static DawnProcTable GetProcs();
+ // TODO(enga): Remove this and use dawn_wire::client::GetProcs() instead
+ static const DawnProcTable& GetProcs();
WGPUDevice GetDevice() const;
const volatile char* HandleCommands(const volatile char* commands,
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 0077695..9c998f5 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -168,6 +168,7 @@
"unittests/MathTests.cpp",
"unittests/ObjectBaseTests.cpp",
"unittests/PerStageTests.cpp",
+ "unittests/PerThreadProcTests.cpp",
"unittests/PlacementAllocatedTests.cpp",
"unittests/RefCountedTests.cpp",
"unittests/ResultTests.cpp",
diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp
index 5873da8..945d9c1 100644
--- a/src/tests/DawnTest.cpp
+++ b/src/tests/DawnTest.cpp
@@ -744,12 +744,9 @@
clientDesc.serializer = mC2sBuf.get();
mWireClient.reset(new dawn_wire::WireClient(clientDesc));
- WGPUDevice clientDevice = mWireClient->GetDevice();
- DawnProcTable clientProcs = dawn_wire::WireClient::GetProcs();
+ cDevice = mWireClient->GetDevice();
+ procs = dawn_wire::client::GetProcs();
mS2cBuf->SetHandler(mWireClient.get());
-
- procs = clientProcs;
- cDevice = clientDevice;
} else {
procs = backendProcs;
cDevice = backendDevice;
diff --git a/src/tests/end2end/WindowSurfaceTests.cpp b/src/tests/end2end/WindowSurfaceTests.cpp
index 8aaa8cc..03e954e 100644
--- a/src/tests/end2end/WindowSurfaceTests.cpp
+++ b/src/tests/end2end/WindowSurfaceTests.cpp
@@ -50,8 +50,7 @@
});
DAWN_SKIP_TEST_IF(!glfwInit());
- DawnProcTable procs = dawn_native::GetProcs();
- dawnProcSetProcs(&procs);
+ dawnProcSetProcs(&dawn_native::GetProcs());
mInstance = wgpu::CreateInstance();
}
diff --git a/src/tests/unittests/GetProcAddressTests.cpp b/src/tests/unittests/GetProcAddressTests.cpp
index f5ac8c6..10f9c5c 100644
--- a/src/tests/unittests/GetProcAddressTests.cpp
+++ b/src/tests/unittests/GetProcAddressTests.cpp
@@ -72,7 +72,7 @@
mWireClient = std::make_unique<dawn_wire::WireClient>(clientDesc);
mDevice = wgpu::Device::Acquire(mWireClient->GetDevice());
- mProcs = dawn_wire::WireClient::GetProcs();
+ mProcs = dawn_wire::client::GetProcs();
break;
}
diff --git a/src/tests/unittests/PerThreadProcTests.cpp b/src/tests/unittests/PerThreadProcTests.cpp
new file mode 100644
index 0000000..38ce981
--- /dev/null
+++ b/src/tests/unittests/PerThreadProcTests.cpp
@@ -0,0 +1,118 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/dawn_thread_dispatch_proc.h"
+#include "dawn/webgpu_cpp.h"
+#include "dawn_native/DawnNative.h"
+#include "dawn_native/Instance.h"
+#include "dawn_native/null/DeviceNull.h"
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <thread>
+
+class PerThreadProcTests : public testing::Test {
+ public:
+ PerThreadProcTests()
+ : mNativeInstance(dawn_native::InstanceBase::Create()),
+ mNativeAdapter(mNativeInstance.Get()) {
+ }
+ ~PerThreadProcTests() override = default;
+
+ protected:
+ Ref<dawn_native::InstanceBase> mNativeInstance;
+ dawn_native::null::Adapter mNativeAdapter;
+};
+
+// Test that procs can be set per thread. This test overrides deviceCreateBuffer with a dummy proc
+// for each thread that increments a counter. Because each thread has their own proc and counter,
+// there should be no data races. The per-thread procs also check that the current thread id is
+// exactly equal to the expected thread id.
+TEST_F(PerThreadProcTests, DispatchesPerThread) {
+ dawnProcSetProcs(&dawnThreadDispatchProcTable);
+
+ // Threads will block on this atomic to be sure we set procs on both threads before
+ // either thread calls the procs.
+ std::atomic<bool> ready(false);
+
+ static int threadACounter = 0;
+ static int threadBCounter = 0;
+
+ static std::atomic<std::thread::id> threadIdA;
+ static std::atomic<std::thread::id> threadIdB;
+
+ constexpr int kThreadATargetCount = 28347;
+ constexpr int kThreadBTargetCount = 40420;
+
+ // Note: Acquire doesn't call reference or release.
+ wgpu::Device deviceA =
+ wgpu::Device::Acquire(reinterpret_cast<WGPUDevice>(mNativeAdapter.CreateDevice(nullptr)));
+
+ wgpu::Device deviceB =
+ wgpu::Device::Acquire(reinterpret_cast<WGPUDevice>(mNativeAdapter.CreateDevice(nullptr)));
+
+ std::thread threadA([&]() {
+ DawnProcTable procs = dawn_native::GetProcs();
+ procs.deviceCreateBuffer = [](WGPUDevice device,
+ WGPUBufferDescriptor const* descriptor) -> WGPUBuffer {
+ EXPECT_EQ(std::this_thread::get_id(), threadIdA);
+ threadACounter++;
+ return nullptr;
+ };
+ dawnProcSetPerThreadProcs(&procs);
+
+ while (!ready) {
+ } // Should be fast, so just spin.
+
+ for (int i = 0; i < kThreadATargetCount; ++i) {
+ deviceA.CreateBuffer(nullptr);
+ }
+
+ deviceA = nullptr;
+ dawnProcSetPerThreadProcs(nullptr);
+ });
+
+ std::thread threadB([&]() {
+ DawnProcTable procs = dawn_native::GetProcs();
+ procs.deviceCreateBuffer = [](WGPUDevice device,
+ WGPUBufferDescriptor const* bufferDesc) -> WGPUBuffer {
+ EXPECT_EQ(std::this_thread::get_id(), threadIdB);
+ threadBCounter++;
+ return nullptr;
+ };
+ dawnProcSetPerThreadProcs(&procs);
+
+ while (!ready) {
+ } // Should be fast, so just spin.
+
+ for (int i = 0; i < kThreadBTargetCount; ++i) {
+ deviceB.CreateBuffer(nullptr);
+ }
+
+ deviceB = nullptr;
+ dawnProcSetPerThreadProcs(nullptr);
+ });
+
+ threadIdA = threadA.get_id();
+ threadIdB = threadB.get_id();
+
+ ready = true;
+ threadA.join();
+ threadB.join();
+
+ EXPECT_EQ(threadACounter, kThreadATargetCount);
+ EXPECT_EQ(threadBCounter, kThreadBTargetCount);
+
+ dawnProcSetProcs(nullptr);
+}
diff --git a/src/tests/unittests/validation/ValidationTest.cpp b/src/tests/unittests/validation/ValidationTest.cpp
index 058fd3b..1713171 100644
--- a/src/tests/unittests/validation/ValidationTest.cpp
+++ b/src/tests/unittests/validation/ValidationTest.cpp
@@ -40,8 +40,7 @@
ASSERT(foundNullAdapter);
- DawnProcTable procs = dawn_native::GetProcs();
- dawnProcSetProcs(&procs);
+ dawnProcSetProcs(&dawn_native::GetProcs());
device = CreateDeviceFromAdapter(adapter, std::vector<const char*>());
}
diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
index 3ba0da0..b442794 100644
--- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
@@ -29,8 +29,7 @@
class WireMultipleDeviceTests : public testing::Test {
protected:
void SetUp() override {
- DawnProcTable procs = dawn_wire::WireClient::GetProcs();
- dawnProcSetProcs(&procs);
+ dawnProcSetProcs(&dawn_wire::client::GetProcs());
}
void TearDown() override {
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index bdac99f..d23709c 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -66,8 +66,7 @@
mS2cBuf->SetHandler(mWireClient.get());
device = mWireClient->GetDevice();
- DawnProcTable clientProcs = dawn_wire::WireClient::GetProcs();
- dawnProcSetProcs(&clientProcs);
+ dawnProcSetProcs(&dawn_wire::client::GetProcs());
apiDevice = mockDevice;