Add a per-thread proc table using thread local storage

In situations where both dawn_wire and dawn_native are used on separate
threads (Chrome with --single-process or --in-process-gpu), it's
desirable to have a per-thread proc table so that the WebGPU C++ API can
still be used. This eliminates whole classes of bugs caused by manual
reference/release mistakes.

This also changes many of the GetProcs functions to return const
references to the static proc tables known at compile time, instead of a
copy.

Bug: none
Change-Id: I8775bb715b312dd9476a1903fbd797d4b1302614
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/29240
Reviewed-by: Stephen White <senorblanco@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py
index b7d9003..02c5723 100644
--- a/generator/dawn_json_generator.py
+++ b/generator/dawn_json_generator.py
@@ -687,6 +687,10 @@
             renders.append(
                 FileRender('dawn_proc.c', 'src/dawn/dawn_proc.c',
                            [base_params, api_params]))
+            renders.append(
+                FileRender('dawn_thread_dispatch_proc.cpp',
+                           'src/dawn/dawn_thread_dispatch_proc.cpp',
+                           [base_params, api_params]))
 
         if 'dawncpp' in targets:
             renders.append(
diff --git a/generator/templates/dawn_native/ProcTable.cpp b/generator/templates/dawn_native/ProcTable.cpp
index 88b780c..defae2d 100644
--- a/generator/templates/dawn_native/ProcTable.cpp
+++ b/generator/templates/dawn_native/ProcTable.cpp
@@ -135,16 +135,21 @@
         return result;
     }
 
-    DawnProcTable GetProcsAutogen() {
-        DawnProcTable table;
-        table.getProcAddress = NativeGetProcAddress;
-        table.createInstance = NativeCreateInstance;
+    // Proc table of the native implementation. The initializers are positional, so their order
+    // must match the member order of DawnProcTable (getProcAddress, createInstance, then the
+    // per-method members -- presumably in the same loop order; confirm in dawn_proc_table.h).
+    static const DawnProcTable gProcTable = {
+        NativeGetProcAddress,
+        NativeCreateInstance,
         {% for type in by_category["object"] %}
             {% for method in c_methods(type) %}
-                table.{{as_varName(type.name, method.name)}} = Native{{as_MethodSuffix(type.name, method.name)}};
+                Native{{as_MethodSuffix(type.name, method.name)}},
             {% endfor %}
         {% endfor %}
-        return table;
-    }
+    };
 
+    // Returns a reference to the static table above instead of building a copy on every call.
+    const DawnProcTable& GetProcsAutogen() {
+        return gProcTable;
+    }
 }
diff --git a/generator/templates/dawn_proc_table.h b/generator/templates/dawn_proc_table.h
index 197f300..1da1f73 100644
--- a/generator/templates/dawn_proc_table.h
+++ b/generator/templates/dawn_proc_table.h
@@ -17,6 +17,7 @@
 
 #include "dawn/webgpu.h"
 
+// Note: Often allocated as a static global. Do not add a complex constructor.
 typedef struct DawnProcTable {
     WGPUProcGetProcAddress getProcAddress;
     WGPUProcCreateInstance createInstance;
diff --git a/generator/templates/dawn_thread_dispatch_proc.cpp b/generator/templates/dawn_thread_dispatch_proc.cpp
new file mode 100644
index 0000000..bfc7794
--- /dev/null
+++ b/generator/templates/dawn_thread_dispatch_proc.cpp
@@ -0,0 +1,63 @@
+#include "dawn/dawn_thread_dispatch_proc.h"
+
+#include <thread>
+
+// All-null proc table (zero-initialized because it has static storage duration). A thread's
+// table is reset to this when dawnProcSetPerThreadProcs(nullptr) is called.
+static DawnProcTable nullProcs;
+// The calling thread's proc table. Every ThreadDispatch* function below forwards through it.
+// Internal linkage: it is only meant to be changed through dawnProcSetPerThreadProcs.
+static thread_local DawnProcTable perThreadProcs;
+
+// Copies |procs| into the calling thread's table, or resets that table to all-null procs when
+// |procs| is nullptr. Only the calling thread's table is written.
+void dawnProcSetPerThreadProcs(const DawnProcTable* procs) {
+    if (procs) {
+        perThreadProcs = *procs;
+    } else {
+        perThreadProcs = nullProcs;
+    }
+}
+
+// The ThreadDispatch* functions forward each call, with unchanged arguments, to the matching
+// entry of the calling thread's table. Entries are not checked for null, so procs must have
+// been set on the thread with dawnProcSetPerThreadProcs before any of them is called.
+static WGPUProc ThreadDispatchGetProcAddress(WGPUDevice device, const char* procName) {
+    return perThreadProcs.getProcAddress(device, procName);
+}
+
+static WGPUInstance ThreadDispatchCreateInstance(WGPUInstanceDescriptor const * descriptor) {
+    return perThreadProcs.createInstance(descriptor);
+}
+
+{% for type in by_category["object"] %}
+    {% for method in c_methods(type) %}
+        static {{as_cType(method.return_type.name)}} ThreadDispatch{{as_MethodSuffix(type.name, method.name)}}(
+            {{-as_cType(type.name)}} {{as_varName(type.name)}}
+            {%- for arg in method.arguments -%}
+                , {{as_annotated_cType(arg)}}
+            {%- endfor -%}
+        ) {
+            {% if method.return_type.name.canonical_case() != "void" %}return {% endif %}
+            perThreadProcs.{{as_varName(type.name, method.name)}}({{as_varName(type.name)}}
+                {%- for arg in method.arguments -%}
+                    , {{as_varName(arg.name)}}
+                {%- endfor -%}
+            );
+        }
+    {% endfor %}
+{% endfor %}
+
+// Proc table made of the ThreadDispatch* functions above. The initializers are positional, so
+// their order must match the member order of DawnProcTable.
+extern "C" {
+    DawnProcTable dawnThreadDispatchProcTable = {
+        ThreadDispatchGetProcAddress,
+        ThreadDispatchCreateInstance,
+{% for type in by_category["object"] %}
+    {% for method in c_methods(type) %}
+        ThreadDispatch{{as_MethodSuffix(type.name, method.name)}},
+    {% endfor %}
+{% endfor %}
+    };
+}
diff --git a/generator/templates/dawn_wire/client/ApiProcs.cpp b/generator/templates/dawn_wire/client/ApiProcs.cpp
index 3edba1a..0ad9a77 100644
--- a/generator/templates/dawn_wire/client/ApiProcs.cpp
+++ b/generator/templates/dawn_wire/client/ApiProcs.cpp
@@ -293,21 +293,21 @@
         return result;
     }
 
-    //* Some commands don't have a custom wire format, but need to be handled manually to update
-    //* some client-side state tracking. For these we have two functions:
-    //*  - An autogenerated Client{{suffix}} method that sends the command on the wire
-    //*  - A manual ProxyClient{{suffix}} method that will be inserted in the proctable instead of
-    //*    the autogenerated one, and that will have to call Client{{suffix}}
-    DawnProcTable GetProcs() {
-        DawnProcTable table;
-        table.getProcAddress = ClientGetProcAddress;
-        table.createInstance = ClientCreateInstance;
+    // Proc table of the wire client. The initializers are positional, so their order must match
+    // the member order of DawnProcTable (getProcAddress, createInstance, then the per-method
+    // members -- presumably in the same loop order; confirm in dawn_proc_table.h).
+    static const DawnProcTable gProcTable = {
+        ClientGetProcAddress,
+        ClientCreateInstance,
         {% for type in by_category["object"] %}
             {% for method in c_methods(type) %}
-                {% set suffix = as_MethodSuffix(type.name, method.name) %}
-                table.{{as_varName(type.name, method.name)}} = Client{{suffix}};
+                Client{{as_MethodSuffix(type.name, method.name)}},
             {% endfor %}
         {% endfor %}
-        return table;
+    };
+
+    // Returns a reference to the static table above instead of building a copy on every call.
+    const DawnProcTable& GetProcs() {
+        return gProcTable;
     }
 }}  // namespace dawn_wire::client
diff --git a/src/dawn/BUILD.gn b/src/dawn/BUILD.gn
index 9034be4..ad48712 100644
--- a/src/dawn/BUILD.gn
+++ b/src/dawn/BUILD.gn
@@ -87,7 +87,10 @@
 
 dawn_json_generator("dawn_proc_gen") {
   target = "dawn_proc"
-  outputs = [ "src/dawn/dawn_proc.c" ]
+  outputs = [
+    "src/dawn/dawn_proc.c",
+    "src/dawn/dawn_thread_dispatch_proc.cpp",
+  ]
 }
 
 dawn_component("dawn_proc") {
@@ -96,5 +99,8 @@
   public_deps = [ ":dawn_headers" ]
   deps = [ ":dawn_proc_gen" ]
   sources = get_target_outputs(":dawn_proc_gen")
-  sources += [ "${dawn_root}/src/include/dawn/dawn_proc.h" ]
+  sources += [
+    "${dawn_root}/src/include/dawn/dawn_proc.h",
+    "${dawn_root}/src/include/dawn/dawn_thread_dispatch_proc.h",
+  ]
 }
diff --git a/src/dawn_native/DawnNative.cpp b/src/dawn_native/DawnNative.cpp
index bfc47db..52efa18 100644
--- a/src/dawn_native/DawnNative.cpp
+++ b/src/dawn_native/DawnNative.cpp
@@ -22,9 +22,9 @@
 
 namespace dawn_native {
 
-    DawnProcTable GetProcsAutogen();
+    const DawnProcTable& GetProcsAutogen();
 
-    DawnProcTable GetProcs() {
+    const DawnProcTable& GetProcs() {
         return GetProcsAutogen();
     }
 
diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp
index e6fe263..430a55c 100644
--- a/src/dawn_wire/WireClient.cpp
+++ b/src/dawn_wire/WireClient.cpp
@@ -26,7 +26,7 @@
     }
 
     // static
-    DawnProcTable WireClient::GetProcs() {
+    const DawnProcTable& WireClient::GetProcs() {
         return client::GetProcs();
     }
 
diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h
index d8df86d..be70e75 100644
--- a/src/dawn_wire/client/Client.h
+++ b/src/dawn_wire/client/Client.h
@@ -68,7 +68,7 @@
         bool mIsDisconnected = false;
     };
 
-    DawnProcTable GetProcs();
+    const DawnProcTable& GetProcs();
 
     std::unique_ptr<MemoryTransferService> CreateInlineMemoryTransferService();
 
diff --git a/src/include/dawn/dawn_thread_dispatch_proc.h b/src/include/dawn/dawn_thread_dispatch_proc.h
new file mode 100644
index 0000000..4d08ba8
--- /dev/null
+++ b/src/include/dawn/dawn_thread_dispatch_proc.h
@@ -0,0 +1,33 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWN_DAWN_THREAD_DISPATCH_PROC_H_
+#define DAWN_DAWN_THREAD_DISPATCH_PROC_H_
+
+#include "dawn/dawn_proc.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Call dawnProcSetProcs(&dawnThreadDispatchProcTable) and then call dawnProcSetPerThreadProcs on
+// each thread to set that thread's procs. Passing nullptr resets them to null function pointers.
+WGPU_EXPORT extern DawnProcTable dawnThreadDispatchProcTable;
+WGPU_EXPORT void dawnProcSetPerThreadProcs(const DawnProcTable* procs);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // DAWN_DAWN_THREAD_DISPATCH_PROC_H_
diff --git a/src/include/dawn_native/DawnNative.h b/src/include/dawn_native/DawnNative.h
index 2199efa..6498646 100644
--- a/src/include/dawn_native/DawnNative.h
+++ b/src/include/dawn_native/DawnNative.h
@@ -170,7 +170,7 @@
     };
 
     // Backend-agnostic API for dawn_native
-    DAWN_NATIVE_EXPORT DawnProcTable GetProcs();
+    DAWN_NATIVE_EXPORT const DawnProcTable& GetProcs();
 
     // Query the names of all the toggles that are enabled in device
     DAWN_NATIVE_EXPORT std::vector<const char*> GetTogglesUsed(WGPUDevice device);
diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h
index 815b66b..50da913 100644
--- a/src/include/dawn_wire/WireClient.h
+++ b/src/include/dawn_wire/WireClient.h
@@ -26,6 +26,8 @@
     namespace client {
         class Client;
         class MemoryTransferService;
+
+        DAWN_WIRE_EXPORT const DawnProcTable& GetProcs();
     }  // namespace client
 
     struct ReservedTexture {
@@ -44,7 +46,8 @@
         WireClient(const WireClientDescriptor& descriptor);
         ~WireClient() override;
 
-        static DawnProcTable GetProcs();
+        // TODO(enga): Remove this and use dawn_wire::client::GetProcs() instead
+        static const DawnProcTable& GetProcs();
 
         WGPUDevice GetDevice() const;
         const volatile char* HandleCommands(const volatile char* commands,
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index 0077695..9c998f5 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -168,6 +168,7 @@
     "unittests/MathTests.cpp",
     "unittests/ObjectBaseTests.cpp",
     "unittests/PerStageTests.cpp",
+    "unittests/PerThreadProcTests.cpp",
     "unittests/PlacementAllocatedTests.cpp",
     "unittests/RefCountedTests.cpp",
     "unittests/ResultTests.cpp",
diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp
index 5873da8..945d9c1 100644
--- a/src/tests/DawnTest.cpp
+++ b/src/tests/DawnTest.cpp
@@ -744,12 +744,9 @@
         clientDesc.serializer = mC2sBuf.get();
 
         mWireClient.reset(new dawn_wire::WireClient(clientDesc));
-        WGPUDevice clientDevice = mWireClient->GetDevice();
-        DawnProcTable clientProcs = dawn_wire::WireClient::GetProcs();
+        cDevice = mWireClient->GetDevice();
+        procs = dawn_wire::client::GetProcs();
         mS2cBuf->SetHandler(mWireClient.get());
-
-        procs = clientProcs;
-        cDevice = clientDevice;
     } else {
         procs = backendProcs;
         cDevice = backendDevice;
diff --git a/src/tests/end2end/WindowSurfaceTests.cpp b/src/tests/end2end/WindowSurfaceTests.cpp
index 8aaa8cc..03e954e 100644
--- a/src/tests/end2end/WindowSurfaceTests.cpp
+++ b/src/tests/end2end/WindowSurfaceTests.cpp
@@ -50,8 +50,7 @@
         });
         DAWN_SKIP_TEST_IF(!glfwInit());
 
-        DawnProcTable procs = dawn_native::GetProcs();
-        dawnProcSetProcs(&procs);
+        dawnProcSetProcs(&dawn_native::GetProcs());
 
         mInstance = wgpu::CreateInstance();
     }
diff --git a/src/tests/unittests/GetProcAddressTests.cpp b/src/tests/unittests/GetProcAddressTests.cpp
index f5ac8c6..10f9c5c 100644
--- a/src/tests/unittests/GetProcAddressTests.cpp
+++ b/src/tests/unittests/GetProcAddressTests.cpp
@@ -72,7 +72,7 @@
                     mWireClient = std::make_unique<dawn_wire::WireClient>(clientDesc);
 
                     mDevice = wgpu::Device::Acquire(mWireClient->GetDevice());
-                    mProcs = dawn_wire::WireClient::GetProcs();
+                    mProcs = dawn_wire::client::GetProcs();
                     break;
                 }
 
diff --git a/src/tests/unittests/PerThreadProcTests.cpp b/src/tests/unittests/PerThreadProcTests.cpp
new file mode 100644
index 0000000..38ce981
--- /dev/null
+++ b/src/tests/unittests/PerThreadProcTests.cpp
@@ -0,0 +1,125 @@
+// Copyright 2020 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn/dawn_thread_dispatch_proc.h"
+#include "dawn/webgpu_cpp.h"
+#include "dawn_native/DawnNative.h"
+#include "dawn_native/Instance.h"
+#include "dawn_native/null/DeviceNull.h"
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <thread>
+
+class PerThreadProcTests : public testing::Test {
+  public:
+    PerThreadProcTests()
+        : mNativeInstance(dawn_native::InstanceBase::Create()),
+          mNativeAdapter(mNativeInstance.Get()) {
+    }
+    ~PerThreadProcTests() override = default;
+
+  protected:
+    Ref<dawn_native::InstanceBase> mNativeInstance;
+    dawn_native::null::Adapter mNativeAdapter;
+};
+
+// Test that procs can be set per thread. This test overrides deviceCreateBuffer with a dummy proc
+// for each thread that increments a counter. Because each thread has their own proc and counter,
+// there should be no data races. The per-thread procs also check that the current thread id is
+// exactly equal to the expected thread id.
+TEST_F(PerThreadProcTests, DispatchesPerThread) {
+    dawnProcSetProcs(&dawnThreadDispatchProcTable);
+
+    // Threads will block on this atomic to be sure we set procs on both threads before
+    // either thread calls the procs.
+    std::atomic<bool> ready(false);
+
+    static int threadACounter = 0;
+    static int threadBCounter = 0;
+    // The counters are static so that the captureless proc lambdas below can reach them. Reset
+    // them explicitly so the expectations at the end still hold when the test body runs more
+    // than once in the same process (e.g. with --gtest_repeat).
+    threadACounter = 0;
+    threadBCounter = 0;
+
+    static std::atomic<std::thread::id> threadIdA;
+    static std::atomic<std::thread::id> threadIdB;
+
+    constexpr int kThreadATargetCount = 28347;
+    constexpr int kThreadBTargetCount = 40420;
+
+    // Note: Acquire doesn't call reference or release.
+    wgpu::Device deviceA =
+        wgpu::Device::Acquire(reinterpret_cast<WGPUDevice>(mNativeAdapter.CreateDevice(nullptr)));
+
+    wgpu::Device deviceB =
+        wgpu::Device::Acquire(reinterpret_cast<WGPUDevice>(mNativeAdapter.CreateDevice(nullptr)));
+
+    std::thread threadA([&]() {
+        DawnProcTable procs = dawn_native::GetProcs();
+        procs.deviceCreateBuffer = [](WGPUDevice device,
+                                      WGPUBufferDescriptor const* descriptor) -> WGPUBuffer {
+            EXPECT_EQ(std::this_thread::get_id(), threadIdA);
+            threadACounter++;
+            return nullptr;
+        };
+        dawnProcSetPerThreadProcs(&procs);
+
+        while (!ready) {
+        }  // Should be fast, so just spin.
+
+        for (int i = 0; i < kThreadATargetCount; ++i) {
+            deviceA.CreateBuffer(nullptr);
+        }
+
+        deviceA = nullptr;
+        dawnProcSetPerThreadProcs(nullptr);
+    });
+
+    std::thread threadB([&]() {
+        DawnProcTable procs = dawn_native::GetProcs();
+        procs.deviceCreateBuffer = [](WGPUDevice device,
+                                      WGPUBufferDescriptor const* bufferDesc) -> WGPUBuffer {
+            EXPECT_EQ(std::this_thread::get_id(), threadIdB);
+            threadBCounter++;
+            return nullptr;
+        };
+        dawnProcSetPerThreadProcs(&procs);
+
+        while (!ready) {
+        }  // Should be fast, so just spin.
+
+        for (int i = 0; i < kThreadBTargetCount; ++i) {
+            deviceB.CreateBuffer(nullptr);
+        }
+
+        deviceB = nullptr;
+        dawnProcSetPerThreadProcs(nullptr);
+    });
+
+    // Store the expected ids before releasing the threads, so the EXPECT_EQ checks inside the
+    // procs compare against the ids of the threads that were actually started.
+    threadIdA = threadA.get_id();
+    threadIdB = threadB.get_id();
+
+    ready = true;
+    threadA.join();
+    threadB.join();
+
+    EXPECT_EQ(threadACounter, kThreadATargetCount);
+    EXPECT_EQ(threadBCounter, kThreadBTargetCount);
+
+    dawnProcSetProcs(nullptr);
+}
diff --git a/src/tests/unittests/validation/ValidationTest.cpp b/src/tests/unittests/validation/ValidationTest.cpp
index 058fd3b..1713171 100644
--- a/src/tests/unittests/validation/ValidationTest.cpp
+++ b/src/tests/unittests/validation/ValidationTest.cpp
@@ -40,8 +40,7 @@
 
     ASSERT(foundNullAdapter);
 
-    DawnProcTable procs = dawn_native::GetProcs();
-    dawnProcSetProcs(&procs);
+    dawnProcSetProcs(&dawn_native::GetProcs());
 
     device = CreateDeviceFromAdapter(adapter, std::vector<const char*>());
 }
diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
index 3ba0da0..b442794 100644
--- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
+++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp
@@ -29,8 +29,7 @@
 class WireMultipleDeviceTests : public testing::Test {
   protected:
     void SetUp() override {
-        DawnProcTable procs = dawn_wire::WireClient::GetProcs();
-        dawnProcSetProcs(&procs);
+        dawnProcSetProcs(&dawn_wire::client::GetProcs());
     }
 
     void TearDown() override {
diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp
index bdac99f..d23709c 100644
--- a/src/tests/unittests/wire/WireTest.cpp
+++ b/src/tests/unittests/wire/WireTest.cpp
@@ -66,8 +66,7 @@
     mS2cBuf->SetHandler(mWireClient.get());
 
     device = mWireClient->GetDevice();
-    DawnProcTable clientProcs = dawn_wire::WireClient::GetProcs();
-    dawnProcSetProcs(&clientProcs);
+    dawnProcSetProcs(&dawn_wire::client::GetProcs());
 
     apiDevice = mockDevice;