dawn_wire::server: Simplify ForwardToServer usage with C++17

This uses template parameter type deduction to pass the member function
pointer and then extract the types that compose it. Which means that the
member function pointer only needs to be written once.

The order of arguments of the Server::On*Callback methods is changed to
put the userdata first. This helps make template type deduction simpler.

Bug: dawn:824
Change-Id: I4e2bc33dfd52a11620dea51b40508eca6c878d72
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/75071
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Brandon Jones <bajones@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h
index 1b685c7..e76f62b 100644
--- a/src/dawn_wire/server/Server.h
+++ b/src/dawn_wire/server/Server.h
@@ -42,13 +42,11 @@
     // auto userdata = MakeUserdata<MyUserdata>();
     // userdata->foo = 2;
     //
-    // // TODO(enga): Make the template inference for ForwardToServer cleaner with C++17
     // callMyCallbackHandler(
-    //      ForwardToServer<decltype(&Server::MyCallbackHandler)>::Func<
-    //                      &Server::MyCallbackHandler>(),
+    //      ForwardToServer<&Server::MyCallbackHandler>,
     //      userdata.release());
     //
-    // void Server::MyCallbackHandler(MyUserdata* userdata) { }
+    // void Server::MyCallbackHandler(MyUserdata* userdata, Other args) { }
     struct CallbackUserdata {
         Server* const server;
         std::weak_ptr<bool> const serverIsAlive;
@@ -59,53 +57,35 @@
         }
     };
 
-    template <typename F>
-    class ForwardToServer;
+    template <auto F>
+    struct ForwardToServerHelper {
+        template <typename _>
+        struct ExtractedTypes;
 
-    template <typename R, typename... Args>
-    class ForwardToServer<R (Server::*)(Args...)> {
-      private:
-        // Get the type T of the last argument. It has CallbackUserdata as its base.
-        using UserdataT = typename std::remove_pointer<typename std::decay<decltype(
-            std::get<sizeof...(Args) - 1>(std::declval<std::tuple<Args...>>()))>::type>::type;
-
-        static_assert(std::is_base_of<CallbackUserdata, UserdataT>::value,
-                      "Last argument of callback handler should derive from CallbackUserdata.");
-
-        template <class T, class... Ts>
-        struct UntypedCallbackImpl;
-
-        template <std::size_t... I, class... Ts>
-        struct UntypedCallbackImpl<std::index_sequence<I...>, Ts...> {
-            template <R (Server::*Func)(Args...)>
-            static auto ForwardToServer(
-                // Unpack and forward the types of the parameter pack.
-                // Append void* as the last argument.
-                typename std::tuple_element<I, std::tuple<Ts...>>::type... args,
-                void* userdata) {
+        // An internal structure used to unpack the various types that compose the type of F
+        template <typename Return, typename Class, typename Userdata, typename... Args>
+        struct ExtractedTypes<Return (Class::*)(Userdata*, Args...)> {
+            using UntypedCallback = Return (*)(Args..., void*);
+            static Return Callback(Args... args, void* userdata) {
                 // Acquire the userdata, and cast it to UserdataT.
-                std::unique_ptr<UserdataT> data(static_cast<UserdataT*>(userdata));
+                std::unique_ptr<Userdata> data(static_cast<Userdata*>(userdata));
                 if (data->serverIsAlive.expired()) {
                     // Do nothing if the server has already been destroyed.
                     return;
                 }
                 // Forward the arguments and the typed userdata to the Server:: member function.
-                (data->server->*Func)(std::forward<decltype(args)>(args)..., data.get());
+                (data->server->*F)(data.get(), std::forward<decltype(args)>(args)...);
             }
         };
 
-        // Generate a free function which has all of the same arguments, except the last
-        // userdata argument is void* instead of UserdataT*. Dawn's userdata args are void*.
-        using UntypedCallback =
-            UntypedCallbackImpl<std::make_index_sequence<sizeof...(Args) - 1>, Args...>;
-
-      public:
-        template <R (Server::*F)(Args...)>
-        static auto Func() {
-            return UntypedCallback::template ForwardToServer<F>;
+        static constexpr typename ExtractedTypes<decltype(F)>::UntypedCallback Create() {
+            return ExtractedTypes<decltype(F)>::Callback;
         }
     };
 
+    template <auto F>
+    constexpr auto ForwardToServer = ForwardToServerHelper<F>::Create();
+
     struct MapUserdata : CallbackUserdata {
         using CallbackUserdata::CallbackUserdata;
 
@@ -217,30 +197,30 @@
         void OnUncapturedError(ObjectHandle device, WGPUErrorType type, const char* message);
         void OnDeviceLost(ObjectHandle device, WGPUDeviceLostReason reason, const char* message);
         void OnLogging(ObjectHandle device, WGPULoggingType type, const char* message);
-        void OnDevicePopErrorScope(WGPUErrorType type,
-                                   const char* message,
-                                   ErrorScopeUserdata* userdata);
-        void OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus status, MapUserdata* userdata);
-        void OnQueueWorkDone(WGPUQueueWorkDoneStatus status, QueueWorkDoneUserdata* userdata);
-        void OnCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus status,
+        void OnDevicePopErrorScope(ErrorScopeUserdata* userdata,
+                                   WGPUErrorType type,
+                                   const char* message);
+        void OnBufferMapAsyncCallback(MapUserdata* userdata, WGPUBufferMapAsyncStatus status);
+        void OnQueueWorkDone(QueueWorkDoneUserdata* userdata, WGPUQueueWorkDoneStatus status);
+        void OnCreateComputePipelineAsyncCallback(CreatePipelineAsyncUserData* userdata,
+                                                  WGPUCreatePipelineAsyncStatus status,
                                                   WGPUComputePipeline pipeline,
-                                                  const char* message,
-                                                  CreatePipelineAsyncUserData* userdata);
-        void OnCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus status,
+                                                  const char* message);
+        void OnCreateRenderPipelineAsyncCallback(CreatePipelineAsyncUserData* userdata,
+                                                 WGPUCreatePipelineAsyncStatus status,
                                                  WGPURenderPipeline pipeline,
-                                                 const char* message,
-                                                 CreatePipelineAsyncUserData* userdata);
-        void OnShaderModuleGetCompilationInfo(WGPUCompilationInfoRequestStatus status,
-                                              const WGPUCompilationInfo* info,
-                                              ShaderModuleGetCompilationInfoUserdata* userdata);
-        void OnRequestAdapterCallback(WGPURequestAdapterStatus status,
+                                                 const char* message);
+        void OnShaderModuleGetCompilationInfo(ShaderModuleGetCompilationInfoUserdata* userdata,
+                                              WGPUCompilationInfoRequestStatus status,
+                                              const WGPUCompilationInfo* info);
+        void OnRequestAdapterCallback(RequestAdapterUserdata* userdata,
+                                      WGPURequestAdapterStatus status,
                                       WGPUAdapter adapter,
-                                      const char* message,
-                                      RequestAdapterUserdata* userdata);
-        void OnRequestDeviceCallback(WGPURequestDeviceStatus status,
+                                      const char* message);
+        void OnRequestDeviceCallback(RequestDeviceUserdata* userdata,
+                                     WGPURequestDeviceStatus status,
                                      WGPUDevice device,
-                                     const char* message,
-                                     RequestDeviceUserdata* userdata);
+                                     const char* message);
 
 #include "dawn_wire/server/ServerPrototypes_autogen.inc"
 
diff --git a/src/dawn_wire/server/ServerAdapter.cpp b/src/dawn_wire/server/ServerAdapter.cpp
index 652d86d..c70b8d8 100644
--- a/src/dawn_wire/server/ServerAdapter.cpp
+++ b/src/dawn_wire/server/ServerAdapter.cpp
@@ -39,18 +39,16 @@
         userdata->requestSerial = requestSerial;
         userdata->deviceObjectId = deviceHandle.id;
 
-        mProcs.adapterRequestDevice(
-            adapter->handle, descriptor,
-            ForwardToServer<decltype(
-                &Server::OnRequestDeviceCallback)>::Func<&Server::OnRequestDeviceCallback>(),
-            userdata.release());
+        mProcs.adapterRequestDevice(adapter->handle, descriptor,
+                                    ForwardToServer<&Server::OnRequestDeviceCallback>,
+                                    userdata.release());
         return true;
     }
 
-    void Server::OnRequestDeviceCallback(WGPURequestDeviceStatus status,
+    void Server::OnRequestDeviceCallback(RequestDeviceUserdata* data,
+                                         WGPURequestDeviceStatus status,
                                          WGPUDevice device,
-                                         const char* message,
-                                         RequestDeviceUserdata* data) {
+                                         const char* message) {
         auto* deviceObject = DeviceObjects().Get(data->deviceObjectId, AllocationState::Reserved);
         // Should be impossible to fail. ObjectIds can't be freed by a destroy command until
         // they move from Reserved to Allocated, or if they are destroyed here.
diff --git a/src/dawn_wire/server/ServerBuffer.cpp b/src/dawn_wire/server/ServerBuffer.cpp
index 05be903..86f011e 100644
--- a/src/dawn_wire/server/ServerBuffer.cpp
+++ b/src/dawn_wire/server/ServerBuffer.cpp
@@ -80,7 +80,7 @@
         // client does the default size computation, we should always have a valid actual size here
         // in server. All other invalid actual size can be caught by dawn native side validation.
         if (offset64 > std::numeric_limits<size_t>::max() || size64 >= WGPU_WHOLE_MAP_SIZE) {
-            OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus_Error, userdata.get());
+            OnBufferMapAsyncCallback(userdata.get(), WGPUBufferMapAsyncStatus_Error);
             return true;
         }
 
@@ -90,11 +90,9 @@
         userdata->offset = offset;
         userdata->size = size;
 
-        mProcs.bufferMapAsync(
-            buffer->handle, mode, offset, size,
-            ForwardToServer<decltype(
-                &Server::OnBufferMapAsyncCallback)>::Func<&Server::OnBufferMapAsyncCallback>(),
-            userdata.release());
+        mProcs.bufferMapAsync(buffer->handle, mode, offset, size,
+                              ForwardToServer<&Server::OnBufferMapAsyncCallback>,
+                              userdata.release());
 
         return true;
     }
@@ -227,7 +225,7 @@
             static_cast<size_t>(offset), static_cast<size_t>(size));
     }
 
-    void Server::OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus status, MapUserdata* data) {
+    void Server::OnBufferMapAsyncCallback(MapUserdata* data, WGPUBufferMapAsyncStatus status) {
         // Skip sending the callback if the buffer has already been destroyed.
         auto* bufferData = BufferObjects().Get(data->buffer.id);
         if (bufferData == nullptr || bufferData->generation != data->buffer.generation) {
diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp
index c8cddf4..139ed3d 100644
--- a/src/dawn_wire/server/ServerDevice.cpp
+++ b/src/dawn_wire/server/ServerDevice.cpp
@@ -91,19 +91,16 @@
 
         ErrorScopeUserdata* unownedUserdata = userdata.release();
         bool success = mProcs.devicePopErrorScope(
-            device->handle,
-            ForwardToServer<decltype(
-                &Server::OnDevicePopErrorScope)>::Func<&Server::OnDevicePopErrorScope>(),
-            unownedUserdata);
+            device->handle, ForwardToServer<&Server::OnDevicePopErrorScope>, unownedUserdata);
         if (!success) {
             delete unownedUserdata;
         }
         return success;
     }
 
-    void Server::OnDevicePopErrorScope(WGPUErrorType type,
-                                       const char* message,
-                                       ErrorScopeUserdata* userdata) {
+    void Server::OnDevicePopErrorScope(ErrorScopeUserdata* userdata,
+                                       WGPUErrorType type,
+                                       const char* message) {
         ReturnDevicePopErrorScopeCallbackCmd cmd;
         cmd.device = userdata->device;
         cmd.requestSerial = userdata->requestSerial;
@@ -139,16 +136,14 @@
 
         mProcs.deviceCreateComputePipelineAsync(
             device->handle, descriptor,
-            ForwardToServer<decltype(&Server::OnCreateComputePipelineAsyncCallback)>::Func<
-                &Server::OnCreateComputePipelineAsyncCallback>(),
-            userdata.release());
+            ForwardToServer<&Server::OnCreateComputePipelineAsyncCallback>, userdata.release());
         return true;
     }
 
-    void Server::OnCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus status,
+    void Server::OnCreateComputePipelineAsyncCallback(CreatePipelineAsyncUserData* data,
+                                                      WGPUCreatePipelineAsyncStatus status,
                                                       WGPUComputePipeline pipeline,
-                                                      const char* message,
-                                                      CreatePipelineAsyncUserData* data) {
+                                                      const char* message) {
         HandleCreateRenderPipelineAsyncCallbackResult<ObjectType::ComputePipeline>(
             &ComputePipelineObjects(), status, pipeline, data);
 
@@ -186,16 +181,14 @@
 
         mProcs.deviceCreateRenderPipelineAsync(
             device->handle, descriptor,
-            ForwardToServer<decltype(&Server::OnCreateRenderPipelineAsyncCallback)>::Func<
-                &Server::OnCreateRenderPipelineAsyncCallback>(),
-            userdata.release());
+            ForwardToServer<&Server::OnCreateRenderPipelineAsyncCallback>, userdata.release());
         return true;
     }
 
-    void Server::OnCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus status,
+    void Server::OnCreateRenderPipelineAsyncCallback(CreatePipelineAsyncUserData* data,
+                                                     WGPUCreatePipelineAsyncStatus status,
                                                      WGPURenderPipeline pipeline,
-                                                     const char* message,
-                                                     CreatePipelineAsyncUserData* data) {
+                                                     const char* message) {
         HandleCreateRenderPipelineAsyncCallbackResult<ObjectType::RenderPipeline>(
             &RenderPipelineObjects(), status, pipeline, data);
 
diff --git a/src/dawn_wire/server/ServerInstance.cpp b/src/dawn_wire/server/ServerInstance.cpp
index 0fb35c0..18d6740 100644
--- a/src/dawn_wire/server/ServerInstance.cpp
+++ b/src/dawn_wire/server/ServerInstance.cpp
@@ -41,18 +41,16 @@
         userdata->requestSerial = requestSerial;
         userdata->adapterObjectId = adapterHandle.id;
 
-        mProcs.instanceRequestAdapter(
-            instance->handle, options,
-            ForwardToServer<decltype(
-                &Server::OnRequestAdapterCallback)>::Func<&Server::OnRequestAdapterCallback>(),
-            userdata.release());
+        mProcs.instanceRequestAdapter(instance->handle, options,
+                                      ForwardToServer<&Server::OnRequestAdapterCallback>,
+                                      userdata.release());
         return true;
     }
 
-    void Server::OnRequestAdapterCallback(WGPURequestAdapterStatus status,
+    void Server::OnRequestAdapterCallback(RequestAdapterUserdata* data,
+                                          WGPURequestAdapterStatus status,
                                           WGPUAdapter adapter,
-                                          const char* message,
-                                          RequestAdapterUserdata* data) {
+                                          const char* message) {
         auto* adapterObject =
             AdapterObjects().Get(data->adapterObjectId, AllocationState::Reserved);
         // Should be impossible to fail. ObjectIds can't be freed by a destroy command until
diff --git a/src/dawn_wire/server/ServerQueue.cpp b/src/dawn_wire/server/ServerQueue.cpp
index 08a5925..54f573e 100644
--- a/src/dawn_wire/server/ServerQueue.cpp
+++ b/src/dawn_wire/server/ServerQueue.cpp
@@ -17,7 +17,7 @@
 
 namespace dawn_wire { namespace server {
 
-    void Server::OnQueueWorkDone(WGPUQueueWorkDoneStatus status, QueueWorkDoneUserdata* data) {
+    void Server::OnQueueWorkDone(QueueWorkDoneUserdata* data, WGPUQueueWorkDoneStatus status) {
         ReturnQueueWorkDoneCallbackCmd cmd;
         cmd.queue = data->queue;
         cmd.requestSerial = data->requestSerial;
@@ -38,10 +38,9 @@
         userdata->queue = ObjectHandle{queueId, queue->generation};
         userdata->requestSerial = requestSerial;
 
-        mProcs.queueOnSubmittedWorkDone(
-            queue->handle, signalValue,
-            ForwardToServer<decltype(&Server::OnQueueWorkDone)>::Func<&Server::OnQueueWorkDone>(),
-            userdata.release());
+        mProcs.queueOnSubmittedWorkDone(queue->handle, signalValue,
+                                        ForwardToServer<&Server::OnQueueWorkDone>,
+                                        userdata.release());
         return true;
     }
 
diff --git a/src/dawn_wire/server/ServerShaderModule.cpp b/src/dawn_wire/server/ServerShaderModule.cpp
index cec0dc4..3a7a671 100644
--- a/src/dawn_wire/server/ServerShaderModule.cpp
+++ b/src/dawn_wire/server/ServerShaderModule.cpp
@@ -29,16 +29,14 @@
         userdata->requestSerial = requestSerial;
 
         mProcs.shaderModuleGetCompilationInfo(
-            shaderModule->handle,
-            ForwardToServer<decltype(&Server::OnShaderModuleGetCompilationInfo)>::Func<
-                &Server::OnShaderModuleGetCompilationInfo>(),
+            shaderModule->handle, ForwardToServer<&Server::OnShaderModuleGetCompilationInfo>,
             userdata.release());
         return true;
     }
 
-    void Server::OnShaderModuleGetCompilationInfo(WGPUCompilationInfoRequestStatus status,
-                                                  const WGPUCompilationInfo* info,
-                                                  ShaderModuleGetCompilationInfoUserdata* data) {
+    void Server::OnShaderModuleGetCompilationInfo(ShaderModuleGetCompilationInfoUserdata* data,
+                                                  WGPUCompilationInfoRequestStatus status,
+                                                  const WGPUCompilationInfo* info) {
         ReturnShaderModuleGetCompilationInfoCallbackCmd cmd;
         cmd.shaderModule = data->shaderModule;
         cmd.requestSerial = data->requestSerial;