[wgpu-headers] Introduce 2nd userdata to requestAdapter callback.
- Updates the wire to use the new APIs
- Adds new infra needed for C++ templated helpers
- Updates relevant usages to use new C++ helpers
Bug: 42241461
Change-Id: Ib6104693ab89af2166ffc56ae90b3a24198e5b31
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/183163
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py
index 11e8809..5e5289b 100644
--- a/generator/dawn_json_generator.py
+++ b/generator/dawn_json_generator.py
@@ -161,6 +161,14 @@
self.is_wire_transparent = True
+class CallbackFunctionType(Type):
+
+ def __init__(self, is_enabled, name, json_data):
+ Type.__init__(self, name, json_data)
+ self.return_type = None
+ self.arguments = []
+
+
class FunctionPointerType(Type):
def __init__(self, is_enabled, name, json_data):
Type.__init__(self, name, json_data)
@@ -328,6 +336,14 @@
return all((member.optional or member.default_value)
for member in self.members)
+
+class CallbackInfoType(StructureType):
+
+ def __init__(self, is_enabled, name, json_data):
+ StructureType.__init__(self, is_enabled, name, json_data)
+ self.extensible = 'in'
+
+
class ConstantDefinition():
def __init__(self, is_enabled, name, json_data):
self.type = None
@@ -502,6 +518,8 @@
disabled_tags, json_data)
category_to_parser = {
'bitmask': BitmaskType,
+ 'callback function': CallbackFunctionType,
+ 'callback info': CallbackInfoType,
'enum': EnumType,
'native': NativeType,
'function pointer': FunctionPointerType,
@@ -549,6 +567,12 @@
types[name] = func_decl
by_category['function'].append(func_decl)
+ for callback_info in by_category['callback info']:
+ link_structure(callback_info, types)
+
+ for callback_function in by_category['callback function']:
+ link_function_pointer(callback_function, types)
+
for function_pointer in by_category['function pointer']:
link_function_pointer(function_pointer, types)
@@ -956,9 +980,25 @@
return any(arg.type.category == 'function pointer' for arg in method.arguments)
+# TODO: crbug.com/dawn/2509 - Remove this helper when once we deprecate older APIs.
def has_callback_info(method):
return method.return_type.name.get() == 'future' and any(
- arg.name.get() == 'callback info' for arg in method.arguments)
+ arg.name.get() == 'callback info'
+ and arg.type.category != 'callback info' for arg in method.arguments)
+
+
+def has_callbackInfoStruct(method):
+ return any(arg.type.category == 'callback info'
+ for arg in method.arguments)
+
+
+def is_wire_serializable(type):
+ # Function pointers, callback functions, and "void *" types (i.e. userdata) cannot
+ # be serialized.
+ return (type.category != 'function pointer'
+ and type.category != 'callback function'
+ and type.name.get() != 'void *')
+
def make_base_render_params(metadata):
c_prefix = metadata.c_prefix
@@ -1016,6 +1056,7 @@
'as_varName': as_varName,
'decorate': decorate,
'as_ktName': as_ktName,
+ 'has_callbackInfoStruct': has_callbackInfoStruct,
}
@@ -1273,6 +1314,7 @@
'as_wireType': lambda type : as_wireType(metadata, type),
'as_annotated_wireType': \
lambda arg: annotated(as_wireType(metadata, arg.type), arg),
+ 'is_wire_serializable': lambda type : is_wire_serializable(type),
}, additional_params
]
renders.append(
diff --git a/generator/templates/api.h b/generator/templates/api.h
index c4c1a3d..43b495e 100644
--- a/generator/templates/api.h
+++ b/generator/templates/api.h
@@ -128,6 +128,15 @@
) {{API}}_FUNCTION_ATTRIBUTE;
{% endfor %}
+// Callback function pointers
+{% for type in by_category["callback function"] %}
+ typedef {{as_cType(type.return_type.name)}} (*{{as_cType(type.name)}})(
+ {%- for arg in type.arguments -%}
+ {{as_annotated_cType(arg)}}{{", "}}
+ {%- endfor -%}
+ void* userdata1, void* userdata2) {{API}}_FUNCTION_ATTRIBUTE;
+{% endfor %}
+
typedef struct {{API}}ChainedStruct {
struct {{API}}ChainedStruct const * next;
{{API}}SType sType;
@@ -143,6 +152,8 @@
nullptr
{%- elif member.type.category == "object" and member.optional -%}
nullptr
+ {%- elif member.type.category == "callback function" -%}
+ nullptr
{%- elif member.type.category in ["enum", "bitmask"] and member.default_value != None -%}
{{as_cEnum(member.type.name, Name(member.default_value))}}
{%- elif member.default_value != None -%}
@@ -193,6 +204,28 @@
})
{% endfor %}
+{% for type in by_category["callback info"] %}
+ typedef struct {{as_cType(type.name)}} {
+ {{API}}ChainedStruct const* nextInChain;
+ {{as_cType(types["callback mode"].name)}} mode;
+ {% for member in type.members %}
+ //* Only callback function types are allowed in callback info structs.
+ {{assert(member.type.category == "callback function")}}{{as_annotated_cType(member)}};
+ {% endfor %}
+ void* userdata1;
+ void* userdata2;
+ } {{as_cType(type.name)}} {{API}}_STRUCTURE_ATTRIBUTE;
+
+ #define {{API}}_{{type.name.SNAKE_CASE()}}_INIT {{API}}_MAKE_INIT_STRUCT({{as_cType(type.name)}}, { \
+ /*.mode=*/{{as_cEnum(types["callback mode"].name, Name("undefined"))}} {{API}}_COMMA \
+ {% for member in type.members %}
+ /*.{{as_varName(member.name)}}=*/{{render_c_default_value(member)}} {{API}}_COMMA \
+ {% endfor %}
+ /*.userdata1=*/nullptr {{API}}_COMMA \
+ /*.userdata2=*/nullptr {{API}}_COMMA \
+ })
+
+{% endfor %}
{% for typeDef in by_category["typedef"] %}
// {{as_cType(typeDef.name)}} is deprecated.
// Use {{as_cType(typeDef.type.name)}} instead.
diff --git a/generator/templates/api_cpp.h b/generator/templates/api_cpp.h
index f450d30..2af158e 100644
--- a/generator/templates/api_cpp.h
+++ b/generator/templates/api_cpp.h
@@ -38,15 +38,12 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
+#include <memory>
#include <functional>
#include "{{c_header}}"
#include "{{api}}/{{api}}_cpp_chained_struct.h"
#include "{{api}}/{{api}}_enum_class_bitmasks.h"
-#include <cmath>
-#include <cstddef>
-#include <cstdint>
-#include <functional>
namespace {{metadata.namespace}} {
@@ -237,6 +234,39 @@
{%- endif -%}
{%- endmacro %}
+//* This rendering macro should ONLY be used for callback info type functions.
+{% macro render_cpp_callback_info_template_method_declaration(type, method, dfn=False) %}
+ {% set CppType = as_cppType(type.name) %}
+ {% set OriginalMethodName = method.name.CamelCase() %}
+ {% set MethodName = OriginalMethodName[:-1] if method.name.chunks[-1] == "2" else OriginalMethodName %}
+ {% set MethodName = CppType + "::" + MethodName if dfn else MethodName %}
+ //* Stripping the 2 at the end of the callback functions for now until we can deprecate old ones.
+ //* TODO: crbug.com/dawn/2509 - Remove name handling once old APIs are deprecated.
+ {% set CallbackInfoType = (method.arguments|last).type %}
+ {% set CallbackType = (CallbackInfoType.members|first).type %}
+ {% set SfinaeArg = " = std::enable_if_t<std::is_convertible_v<F, Cb*>>" if not dfn else "" %}
+ template <typename F, typename T,
+ typename Cb
+ {%- if not dfn -%}
+ {{" "}}= void (
+ {%- for arg in CallbackType.arguments -%}
+ {{as_annotated_cppType(arg)}}{{", "}}
+ {%- endfor -%}
+ T userdata)
+ {%- endif -%},
+ typename{{SfinaeArg}}>
+ {{as_cppType(method.return_type.name)}} {{MethodName}}(
+ {%- for arg in method.arguments if arg.type.category != "callback info" -%}
+ {%- if arg.type.category == "object" and arg.annotation == "value" -%}
+ {{as_cppType(arg.type.name)}} const& {{as_varName(arg.name)}}{{ ", "}}
+ {%- else -%}
+ {{as_annotated_cppType(arg)}}{{ ", "}}
+ {%- endif -%}
+ {%- endfor -%}
+ {{as_cppType(types["callback mode"].name)}} mode, F callback, T userdata) const
+{%- endmacro %}
+
+//* This rendering macro should NOT be used for callback info type functions.
{% macro render_cpp_method_declaration(type, method, dfn=False) %}
{% set CppType = as_cppType(type.name) %}
{% set OriginalMethodName = method.name.CamelCase() %}
@@ -281,6 +311,48 @@
struct {{as_cppType(type.name)}};
{% endfor %}
+{% macro render_cpp_callback_info_template_method_impl(type, method) %}
+ {{render_cpp_callback_info_template_method_declaration(type, method, dfn=True)}} {
+ {% set CallbackInfoType = (method.arguments|last).type %}
+ {% set CallbackType = (CallbackInfoType.members|first).type %}
+ {{as_cType(CallbackInfoType.name)}} callbackInfo = {};
+ callbackInfo.mode = static_cast<{{as_cType(types["callback mode"].name)}}>(mode);
+ callbackInfo.callback = [](
+ {%- for arg in CallbackType.arguments -%}
+ {{as_annotated_cType(arg)}}{{", "}}
+ {%- endfor -%}
+ void* callback, void* userdata) {
+ auto cb = reinterpret_cast<Cb*>(callback);
+ (*cb)(
+ {%- for arg in CallbackType.arguments -%}
+ {{convert_cType_to_cppType(arg.type, arg.annotation, as_varName(arg.name))}}{{", "}}
+ {%- endfor -%}
+ static_cast<T>(userdata));
+ };
+ callbackInfo.userdata1 = reinterpret_cast<void*>(+callback);
+ callbackInfo.userdata2 = reinterpret_cast<void*>(userdata);
+ auto result = {{as_cMethod(type.name, method.name)}}(Get(){{", "}}
+ {%- for arg in method.arguments if arg.type.category != "callback info" -%}{{render_c_actual_arg(arg)}}{{", "}}
+ {%- endfor -%}
+ callbackInfo);
+ return {{convert_cType_to_cppType(method.return_type, 'value', 'result') | indent(8)}};
+ }
+{%- endmacro %}
+
+{% macro render_cpp_method_impl(type, method) %}
+ {{render_cpp_method_declaration(type, method, dfn=True)}} {
+ {% for arg in method.arguments if arg.type.has_free_members_function and arg.annotation == '*' %}
+ *{{as_varName(arg.name)}} = {{as_cppType(arg.type.name)}}();
+ {% endfor %}
+ {% if method.return_type.name.concatcase() == "void" %}
+ {{render_cpp_to_c_method_call(type, method)}};
+ {% else %}
+ auto result = {{render_cpp_to_c_method_call(type, method)}};
+ return {{convert_cType_to_cppType(method.return_type, 'value', 'result') | indent(8)}};
+ {% endif %}
+ }
+{%- endmacro %}
+
{% for type in by_category["object"] %}
{% set CppType = as_cppType(type.name) %}
{% set CType = as_cType(type.name) %}
@@ -290,7 +362,11 @@
using ObjectBase::operator=;
{% for method in type.methods %}
- inline {{render_cpp_method_declaration(type, method)}};
+ {% if has_callbackInfoStruct(method) %}
+ {{render_cpp_callback_info_template_method_declaration(type, method)|indent}};
+ {% else %}
+ inline {{render_cpp_method_declaration(type, method)}};
+ {% endif %}
{% endfor %}
private:
@@ -462,17 +538,11 @@
// {{CppType}} implementation
{% for method in type.methods %}
- {{render_cpp_method_declaration(type, method, dfn=True)}} {
- {% for arg in method.arguments if arg.type.has_free_members_function and arg.annotation == '*' %}
- *{{as_varName(arg.name)}} = {{as_cppType(arg.type.name)}}();
- {% endfor %}
- {% if method.return_type.name.concatcase() == "void" %}
- {{render_cpp_to_c_method_call(type, method)}};
- {% else %}
- auto result = {{render_cpp_to_c_method_call(type, method)}};
- return {{convert_cType_to_cppType(method.return_type, 'value', 'result') | indent(8)}};
- {% endif %}
- }
+ {% if has_callbackInfoStruct(method) %}
+ {{render_cpp_callback_info_template_method_impl(type, method)}}
+ {% else %}
+ {{render_cpp_method_impl(type, method)}}
+ {% endif %}
{% endfor %}
void {{CppType}}::{{c_prefix}}AddRef({{CType}} handle) {
if (handle != nullptr) {
diff --git a/generator/templates/dawn/wire/WireCmd.cpp b/generator/templates/dawn/wire/WireCmd.cpp
index ec1faaa..3cadd30 100644
--- a/generator/templates/dawn/wire/WireCmd.cpp
+++ b/generator/templates/dawn/wire/WireCmd.cpp
@@ -84,8 +84,7 @@
//* trusted boundary.
{%- set Provider = ", provider" if member.type.may_have_dawn_object else "" -%}
WIRE_TRY({{as_cType(member.type.name)}}Serialize({{in}}, &{{out}}, buffer{{Provider}}));
- {%- elif member.type.category == "function pointer" or member.type.name.get() == "void *" -%}
- //* Function pointers and explicit "void *" types (i.e. userdata) cannot be serialized.
+ {%- elif not is_wire_serializable(member.type) -%}
if ({{in}} != nullptr) return WireResult::FatalError;
{%- else -%}
{{out}} = {{in}};
@@ -108,8 +107,7 @@
{%- endif -%}
));
{%- endif -%}
- {%- elif member.type.category == "function pointer" or member.type.name.get() == "void *" %}
- //* Function pointers and explicit "void *" types (i.e. userdata) cannot be deserialized.
+ {%- elif not is_wire_serializable(member.type) %}
{{out}} = nullptr;
{%- elif member.type.name.get() == "size_t" -%}
//* Deserializing into size_t requires check that the uint64_t used on the wire won't narrow.
@@ -146,8 +144,7 @@
{% endif %}
{% for member in members %}
- //* Function pointers and explicit "void *" types (i.e. userdata) do not get serialized.
- {% if member.type.category == "function pointer" or member.type.name.get() == "void *" %}
+ {% if not is_wire_serializable(member.type) %}
{% continue %}
{% endif %}
//* Value types are directly in the command, objects being replaced with their IDs.
diff --git a/generator/templates/mock_api.cpp b/generator/templates/mock_api.cpp
index c7eeb8c..7cd252a 100644
--- a/generator/templates/mock_api.cpp
+++ b/generator/templates/mock_api.cpp
@@ -147,6 +147,41 @@
);
return {mNextFutureID++};
}
+ {% elif has_callbackInfoStruct(method) %}
+ {{as_cType(method.return_type.name)}} ProcTableAsClass::{{Suffix}}(
+ {{-as_cType(type.name)}} {{as_varName(type.name)}}
+ {%- for arg in method.arguments -%}
+ , {{as_annotated_cType(arg)}}
+ {%- endfor -%}
+ ) {
+ ProcTableAsClass::Object* object = reinterpret_cast<ProcTableAsClass::Object*>({{as_varName(type.name)}});
+ object->m{{Suffix}}Callback = callbackInfo.callback;
+ object->m{{Suffix}}Userdata1 = callbackInfo.userdata1;
+ object->m{{Suffix}}Userdata2 = callbackInfo.userdata2;
+
+ On{{Suffix}}(
+ {{-as_varName(type.name)}}
+ {%- for arg in method.arguments -%}
+ , {{as_varName(arg.name)}}
+ {%- endfor -%}
+ );
+ return {mNextFutureID++};
+ }
+ {% set CallbackInfoType = (method.arguments|last).type %}
+ {% set CallbackType = (CallbackInfoType.members|first).type %}
+ void ProcTableAsClass::Call{{Suffix}}Callback(
+ {{-as_cType(type.name)}} {{as_varName(type.name)}}
+ {%- for arg in CallbackType.arguments -%}
+ , {{as_annotated_cType(arg)}}
+ {%- endfor -%}
+ ) {
+ ProcTableAsClass::Object* object = reinterpret_cast<ProcTableAsClass::Object*>({{as_varName(type.name)}});
+ object->m{{Suffix}}Callback(
+ {%- for arg in CallbackType.arguments -%}
+ {{as_varName(arg.name)}}{{", "}}
+ {%- endfor -%}
+ object->m{{Suffix}}Userdata1, object->m{{Suffix}}Userdata2);
+ }
{% endif %}
{% endfor %}
diff --git a/generator/templates/mock_api.h b/generator/templates/mock_api.h
index 7e9fc4a..e765d0e 100644
--- a/generator/templates/mock_api.h
+++ b/generator/templates/mock_api.h
@@ -78,7 +78,7 @@
}
{% for method in type.methods if method.name.get() not in ManuallyMockedFunctions %}
{% set Suffix = as_CppMethodSuffix(type.name, method.name) %}
- {% if not has_callback_arguments(method) and not has_callback_info(method) %}
+ {% if not has_callback_arguments(method) and not has_callback_info(method) and not has_callbackInfoStruct(method) %}
virtual {{as_cType(method.return_type.name)}} {{Suffix}}(
{{-as_cType(type.name)}} {{as_varName(type.name)}}
{%- for arg in method.arguments -%}
@@ -96,6 +96,26 @@
{% endif %}
{% endfor %}
+ {% for method in type.methods if has_callbackInfoStruct(method) %}
+ {% set Suffix = as_CppMethodSuffix(type.name, method.name) %}
+ //* The virtual function to call after saving the callback and userdata in the proc.
+ //* This function can be mocked.
+ virtual void On{{Suffix}}(
+ {{-as_cType(type.name)}} {{as_varName(type.name)}}
+ {%- for arg in method.arguments -%}
+ , {{as_annotated_cType(arg)}}
+ {%- endfor -%}
+ ) = 0;
+ {% set CallbackInfoType = (method.arguments|last).type %}
+ {% set CallbackType = (CallbackInfoType.members|first).type %}
+ void Call{{Suffix}}Callback(
+ {{-as_cType(type.name)}} {{as_varName(type.name)}}
+ {%- for arg in CallbackType.arguments -%}
+ , {{as_annotated_cType(arg)}}
+ {%- endfor -%}
+ );
+ {% endfor %}
+
{%- for method in type.methods if has_callback_info(method) or method.name.get() in LegacyCallbackFunctions %}
{% set Suffix = as_CppMethodSuffix(type.name, method.name) %}
@@ -155,6 +175,13 @@
{% endif %}
{% endfor %}
{% endfor %}
+ {% for method in type.methods if has_callbackInfoStruct(method) %}
+ {% set CallbackInfoType = (method.arguments|last).type %}
+ {% set CallbackType = (CallbackInfoType.members|first).type %}
+ void* m{{as_CppMethodSuffix(type.name, method.name)}}Userdata1 = 0;
+ void* m{{as_CppMethodSuffix(type.name, method.name)}}Userdata2 = 0;
+ {{as_cType(CallbackType.name)}} m{{as_CppMethodSuffix(type.name, method.name)}}Callback = nullptr;
+ {% endfor %}
{% endfor %}
// Manually implement device lost related callback helpers for testing.
WGPUDeviceLostCallback mDeviceLostOldCallback = nullptr;
@@ -180,7 +207,7 @@
MOCK_METHOD(void, {{as_MethodSuffix(type.name, Name("add ref"))}}, ({{as_cType(type.name)}} self), (override));
MOCK_METHOD(void, {{as_MethodSuffix(type.name, Name("release"))}}, ({{as_cType(type.name)}} self), (override));
- {% for method in type.methods if not has_callback_arguments(method) and not has_callback_info(method) %}
+ {% for method in type.methods if not has_callback_arguments(method) and not has_callback_info(method) and not has_callbackInfoStruct(method) %}
MOCK_METHOD({{as_cType(method.return_type.name)}},{{" "}}
{{-as_MethodSuffix(type.name, method.name)}}, (
{{-as_cType(type.name)}} {{as_varName(type.name)}}
@@ -199,6 +226,15 @@
{%- endfor -%}
), (override));
{% endfor %}
+ {% for method in type.methods if has_callbackInfoStruct(method) %}
+ MOCK_METHOD(void,{{" "-}}
+ On{{as_CppMethodSuffix(type.name, method.name)}}, (
+ {{-as_cType(type.name)}} {{as_varName(type.name)}}
+ {%- for arg in method.arguments -%}
+ , {{as_annotated_cType(arg)}}
+ {%- endfor -%}
+ ), (override));
+ {% endfor %}
{% endfor %}
// Manually implement device lost related callback helpers for testing.
diff --git a/src/dawn/dawn.json b/src/dawn/dawn.json
index 57198de..fd97565 100644
--- a/src/dawn/dawn.json
+++ b/src/dawn/dawn.json
@@ -81,6 +81,14 @@
{"name": "userdata", "type": "void *"}
]
},
+ "request adapter callback 2": {
+ "category": "callback function",
+ "args": [
+ {"name": "status", "type": "request adapter status"},
+ {"name": "adapter", "type": "adapter", "optional": true},
+ {"name": "message", "type": "char", "annotation": "const*", "length": "strlen", "optional": true}
+ ]
+ },
"request adapter callback info": {
"category": "structure",
"extensible": "in",
@@ -90,6 +98,12 @@
{"name": "userdata", "type": "void *"}
]
},
+ "request adapter callback info 2": {
+ "category": "callback info",
+ "members": [
+ {"name": "callback", "type": "request adapter callback 2"}
+ ]
+ },
"request adapter status": {
"category": "enum",
"emscripten_no_enum_table": true,
@@ -2397,6 +2411,16 @@
]
},
{
+ "name": "request adapter 2",
+ "_comment": "TODO(crbug.com/dawn/2021): This is dawn/emscripten-only until we rename it to replace the old API. See bug for details.",
+ "tags": ["dawn", "emscripten"],
+ "returns": "future",
+ "args": [
+ {"name": "options", "type": "request adapter options", "annotation": "const*", "optional": true, "no_default": true},
+ {"name": "callback info", "type": "request adapter callback info 2"}
+ ]
+ },
+ {
"name": "has WGSL language feature",
"returns": "bool",
"args": [
diff --git a/src/dawn/dawn_wire.json b/src/dawn/dawn_wire.json
index 3f3c7d0e..62d42bd 100644
--- a/src/dawn/dawn_wire.json
+++ b/src/dawn/dawn_wire.json
@@ -108,7 +108,8 @@
{ "name": "event manager handle", "type": "ObjectHandle" },
{ "name": "future", "type": "future" },
{ "name": "adapter object handle", "type": "ObjectHandle", "handle_type": "adapter"},
- { "name": "options", "type": "request adapter options", "annotation": "const*", "optional": true }
+ { "name": "options", "type": "request adapter options", "annotation": "const*", "optional": true },
+ { "name": "userdata count", "type": "uint8_t", "_comment": "TODO(crbug.com/dawn/2509): Remove this once Chromium overrides the correct functions in the proc table."}
],
"adapter request device": [
{ "name": "adapter id", "type": "ObjectId", "id_type": "adapter" },
@@ -235,6 +236,7 @@
"InstanceHasWGSLLanguageFeature",
"InstanceRequestAdapter",
"InstanceRequestAdapterF",
+ "InstanceRequestAdapter2",
"ShaderModuleGetCompilationInfo",
"ShaderModuleGetCompilationInfoF",
"QuerySetGetType",
diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp
index a0a7b41..76945fa 100644
--- a/src/dawn/native/Instance.cpp
+++ b/src/dawn/native/Instance.cpp
@@ -282,16 +282,31 @@
Future InstanceBase::APIRequestAdapterF(const RequestAdapterOptions* options,
const RequestAdapterCallbackInfo& callbackInfo) {
+ return APIRequestAdapter2(
+ options, {nullptr, ToAPI(callbackInfo.mode),
+ [](WGPURequestAdapterStatus status, WGPUAdapter adapter, char const* message,
+ void* callback, void* userdata) {
+ auto cb = reinterpret_cast<WGPURequestAdapterCallback>(callback);
+ cb(status, adapter, message, userdata);
+ },
+ reinterpret_cast<void*>(callbackInfo.callback), callbackInfo.userdata});
+}
+
+Future InstanceBase::APIRequestAdapter2(const RequestAdapterOptions* options,
+ const WGPURequestAdapterCallbackInfo2& callbackInfo) {
struct RequestAdapterEvent final : public EventManager::TrackedEvent {
- WGPURequestAdapterCallback mCallback;
- raw_ptr<void> mUserdata;
+ WGPURequestAdapterCallback2 mCallback;
+ raw_ptr<void> mUserdata1;
+ raw_ptr<void> mUserdata2;
Ref<AdapterBase> mAdapter;
- RequestAdapterEvent(const RequestAdapterCallbackInfo& callbackInfo,
+ RequestAdapterEvent(const WGPURequestAdapterCallbackInfo2& callbackInfo,
Ref<AdapterBase> adapter)
- : TrackedEvent(callbackInfo.mode, TrackedEvent::Completed{}),
+ : TrackedEvent(static_cast<wgpu::CallbackMode>(callbackInfo.mode),
+ TrackedEvent::Completed{}),
mCallback(callbackInfo.callback),
- mUserdata(callbackInfo.userdata),
+ mUserdata1(callbackInfo.userdata1),
+ mUserdata2(callbackInfo.userdata2),
mAdapter(std::move(adapter)) {}
~RequestAdapterEvent() override { EnsureComplete(EventCompletionType::Shutdown); }
@@ -299,17 +314,17 @@
void Complete(EventCompletionType completionType) override {
if (completionType == EventCompletionType::Shutdown) {
mCallback(WGPURequestAdapterStatus_InstanceDropped, nullptr, nullptr,
- mUserdata.ExtractAsDangling());
+ mUserdata1.ExtractAsDangling(), mUserdata2.ExtractAsDangling());
return;
}
WGPUAdapter adapter = ToAPI(ReturnToAPI(std::move(mAdapter)));
if (adapter == nullptr) {
mCallback(WGPURequestAdapterStatus_Unavailable, nullptr, "No supported adapters",
- mUserdata.ExtractAsDangling());
+ mUserdata1.ExtractAsDangling(), mUserdata2.ExtractAsDangling());
} else {
mCallback(WGPURequestAdapterStatus_Success, adapter, nullptr,
- mUserdata.ExtractAsDangling());
+ mUserdata1.ExtractAsDangling(), mUserdata2.ExtractAsDangling());
}
}
};
@@ -357,7 +372,8 @@
const RequestAdapterOptions* options) {
static constexpr RequestAdapterOptions kDefaultOptions = {};
if (options == nullptr) {
- // Default path that returns all WebGPU core adapters on the system with default toggles.
+ // Default path that returns all WebGPU core adapters on the system with default
+ // toggles.
return EnumerateAdapters(&kDefaultOptions);
}
diff --git a/src/dawn/native/Instance.h b/src/dawn/native/Instance.h
index 7aeda22..f6c7d2c 100644
--- a/src/dawn/native/Instance.h
+++ b/src/dawn/native/Instance.h
@@ -82,6 +82,8 @@
void* userdata);
Future APIRequestAdapterF(const RequestAdapterOptions* options,
const RequestAdapterCallbackInfo& callbackInfo);
+ Future APIRequestAdapter2(const RequestAdapterOptions* options,
+ const WGPURequestAdapterCallbackInfo2& callbackInfo);
// Discovers and returns a vector of adapters.
// All systems adapters that can be found are returned if no options are passed.
diff --git a/src/dawn/tests/DawnTest.cpp b/src/dawn/tests/DawnTest.cpp
index cb9bfcb..0fbfa98 100644
--- a/src/dawn/tests/DawnTest.cpp
+++ b/src/dawn/tests/DawnTest.cpp
@@ -724,9 +724,10 @@
// Override procs to provide harness-specific behavior to always select the adapter required in
// testing parameter, and to allow fixture-specific overriding of the test device with
// CreateDeviceImpl.
- procs.instanceRequestAdapter = [](WGPUInstance cInstance, const WGPURequestAdapterOptions*,
- WGPURequestAdapterCallback callback, void* userdata) {
+ procs.instanceRequestAdapter2 = [](WGPUInstance cInstance, const WGPURequestAdapterOptions*,
+ WGPURequestAdapterCallbackInfo2 callbackInfo) -> WGPUFuture {
DAWN_ASSERT(gCurrentTest);
+ DAWN_ASSERT(callbackInfo.mode == WGPUCallbackMode_AllowSpontaneous);
// Use the required toggles of test case when creating adapter.
const auto& enabledToggles = gCurrentTest->mParam.forceEnabledWorkarounds;
@@ -762,7 +763,11 @@
WGPUAdapter cAdapter = it->Get();
DAWN_ASSERT(cAdapter);
native::GetProcs().adapterAddRef(cAdapter);
- callback(WGPURequestAdapterStatus_Success, cAdapter, nullptr, userdata);
+ callbackInfo.callback(WGPURequestAdapterStatus_Success, cAdapter, nullptr,
+ callbackInfo.userdata1, callbackInfo.userdata2);
+
+ // Returning a placeholder future that we should never be waiting on.
+ return {0};
};
procs.adapterRequestDevice = [](WGPUAdapter cAdapter, const WGPUDeviceDescriptor* descriptor,
@@ -1189,10 +1194,9 @@
// RequestAdapter is overriden to ignore RequestAdapterOptions, and select based on test params.
instance.RequestAdapter(
- nullptr,
- [](WGPURequestAdapterStatus, WGPUAdapter cAdapter, const char*, void* userdata) {
- *static_cast<wgpu::Adapter*>(userdata) = wgpu::Adapter::Acquire(cAdapter);
- },
+ nullptr, wgpu::CallbackMode::AllowSpontaneous,
+ [](wgpu::RequestAdapterStatus status, wgpu::Adapter result, char const* message,
+ wgpu::Adapter* userdata) -> void { *userdata = std::move(result); },
&adapter);
FlushWire();
DAWN_ASSERT(adapter);
diff --git a/src/dawn/tests/unittests/validation/ValidationTest.cpp b/src/dawn/tests/unittests/validation/ValidationTest.cpp
index fc2432c..21328b4 100644
--- a/src/dawn/tests/unittests/validation/ValidationTest.cpp
+++ b/src/dawn/tests/unittests/validation/ValidationTest.cpp
@@ -106,21 +106,23 @@
// Forward to dawn::native instanceRequestAdapter, but save the returned adapter in
// gCurrentTest->mBackendAdapter.
- procs.instanceRequestAdapter = [](WGPUInstance i, const WGPURequestAdapterOptions* options,
- WGPURequestAdapterCallback callback, void* userdata) {
+ procs.instanceRequestAdapter2 = [](WGPUInstance i, const WGPURequestAdapterOptions* options,
+ WGPURequestAdapterCallbackInfo2 callbackInfo) -> WGPUFuture {
DAWN_ASSERT(gCurrentTest);
- dawn::native::GetProcs().instanceRequestAdapter(
- i, options,
- [](WGPURequestAdapterStatus status, WGPUAdapter cAdapter, char const* message,
- void* userdata) {
- gCurrentTest->mBackendAdapter = dawn::native::FromAPI(cAdapter);
+ DAWN_ASSERT(callbackInfo.mode == WGPUCallbackMode_AllowSpontaneous);
- auto* callbackAndUserdata =
- static_cast<std::pair<WGPURequestAdapterCallback, void*>*>(userdata);
- callbackAndUserdata->first(status, cAdapter, message, callbackAndUserdata->second);
- delete callbackAndUserdata;
- },
- new std::pair<WGPURequestAdapterCallback, void*>(callback, userdata));
+ return dawn::native::GetProcs().instanceRequestAdapter2(
+ i, options,
+ {nullptr, WGPUCallbackMode_AllowSpontaneous,
+ [](WGPURequestAdapterStatus status, WGPUAdapter cAdapter, char const* message,
+ void* userdata, void*) {
+ gCurrentTest->mBackendAdapter = dawn::native::FromAPI(cAdapter);
+
+ auto* info = static_cast<WGPURequestAdapterCallbackInfo2*>(userdata);
+ info->callback(status, cAdapter, message, info->userdata1, info->userdata2);
+ delete info;
+ },
+ new WGPURequestAdapterCallbackInfo2(callbackInfo), nullptr});
};
procs.adapterRequestDevice = [](WGPUAdapter self, const WGPUDeviceDescriptor* descriptor,
@@ -284,10 +286,9 @@
void ValidationTest::CreateTestAdapter(wgpu::RequestAdapterOptions options) {
instance.RequestAdapter(
- &options,
- [](WGPURequestAdapterStatus, WGPUAdapter cAdapter, const char*, void* userdata) {
- *static_cast<wgpu::Adapter*>(userdata) = wgpu::Adapter::Acquire(cAdapter);
- },
+ &options, wgpu::CallbackMode::AllowSpontaneous,
+ [](wgpu::RequestAdapterStatus status, wgpu::Adapter result, char const* message,
+ wgpu::Adapter* userdata) -> void { *userdata = std::move(result); },
&adapter);
FlushWire();
}
diff --git a/src/dawn/tests/unittests/wire/WireTest.cpp b/src/dawn/tests/unittests/wire/WireTest.cpp
index 730ba13..63c3552 100644
--- a/src/dawn/tests/unittests/wire/WireTest.cpp
+++ b/src/dawn/tests/unittests/wire/WireTest.cpp
@@ -91,10 +91,11 @@
// Create the adapter for testing.
apiAdapter = api.GetNewAdapter();
WGPURequestAdapterOptions adapterOpts = {};
- MockCallback<WGPURequestAdapterCallback> adapterCb;
- wgpuInstanceRequestAdapter(instance, &adapterOpts, adapterCb.Callback(),
- adapterCb.MakeUserdata(this));
- EXPECT_CALL(api, OnInstanceRequestAdapter(apiInstance, NotNull(), _)).WillOnce([&]() {
+ MockCallback<WGPURequestAdapterCallback2> adapterCb;
+ wgpuInstanceRequestAdapter2(instance, &adapterOpts,
+ {nullptr, WGPUCallbackMode_AllowSpontaneous, adapterCb.Callback(),
+ nullptr, adapterCb.MakeUserdata(this)});
+ EXPECT_CALL(api, OnInstanceRequestAdapter2(apiInstance, NotNull(), _)).WillOnce([&]() {
EXPECT_CALL(api, AdapterHasFeature(apiAdapter, _)).WillRepeatedly(Return(false));
EXPECT_CALL(api, AdapterGetProperties(apiAdapter, NotNull()))
@@ -116,12 +117,13 @@
.WillOnce(Return(0))
.WillOnce(Return(0));
- api.CallInstanceRequestAdapterCallback(apiInstance, WGPURequestAdapterStatus_Success,
- apiAdapter, nullptr);
+ api.CallInstanceRequestAdapter2Callback(apiInstance, WGPURequestAdapterStatus_Success,
+ apiAdapter, nullptr);
});
FlushClient();
WGPUAdapter cAdapter = nullptr;
- EXPECT_CALL(adapterCb, Call(WGPURequestAdapterStatus_Success, NotNull(), nullptr, this))
+ EXPECT_CALL(adapterCb,
+ Call(WGPURequestAdapterStatus_Success, NotNull(), nullptr, nullptr, this))
.WillOnce(SaveArg<1>(&cAdapter));
FlushServer();
EXPECT_NE(cAdapter, nullptr);
diff --git a/src/dawn/wire/client/Instance.cpp b/src/dawn/wire/client/Instance.cpp
index 8a764d1..c4ed227 100644
--- a/src/dawn/wire/client/Instance.cpp
+++ b/src/dawn/wire/client/Instance.cpp
@@ -51,7 +51,14 @@
RequestAdapterEvent(const WGPURequestAdapterCallbackInfo& callbackInfo, Adapter* adapter)
: TrackedEvent(callbackInfo.mode),
mCallback(callbackInfo.callback),
- mUserdata(callbackInfo.userdata),
+ mUserdata1(callbackInfo.userdata),
+ mAdapter(adapter) {}
+
+ RequestAdapterEvent(const WGPURequestAdapterCallbackInfo2& callbackInfo, Adapter* adapter)
+ : TrackedEvent(callbackInfo.mode),
+ mCallback2(callbackInfo.callback),
+ mUserdata1(callbackInfo.userdata1),
+ mUserdata2(callbackInfo.userdata2),
mAdapter(adapter) {}
EventType GetType() override { return kType; }
@@ -78,10 +85,11 @@
private:
void CompleteImpl(FutureID futureID, EventCompletionType completionType) override {
- if (mCallback == nullptr) {
+ if (mCallback == nullptr && mCallback2 == nullptr) {
// If there's no callback, just clean up the resources.
mAdapter.ExtractAsDangling()->Release();
- mUserdata.ExtractAsDangling();
+ mUserdata1.ExtractAsDangling();
+ mUserdata2.ExtractAsDangling();
return;
}
@@ -91,12 +99,22 @@
}
Adapter* adapter = mAdapter.ExtractAsDangling();
- mCallback(mStatus, ToAPI(mStatus == WGPURequestAdapterStatus_Success ? adapter : nullptr),
- mMessage ? mMessage->c_str() : nullptr, mUserdata.ExtractAsDangling());
+ if (mCallback) {
+ mCallback(mStatus,
+ ToAPI(mStatus == WGPURequestAdapterStatus_Success ? adapter : nullptr),
+ mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling());
+ } else {
+ mCallback2(mStatus,
+ ToAPI(mStatus == WGPURequestAdapterStatus_Success ? adapter : nullptr),
+ mMessage ? mMessage->c_str() : nullptr, mUserdata1.ExtractAsDangling(),
+ mUserdata2.ExtractAsDangling());
+ }
}
- WGPURequestAdapterCallback mCallback;
- raw_ptr<void> mUserdata;
+ WGPURequestAdapterCallback mCallback = nullptr;
+ WGPURequestAdapterCallback2 mCallback2 = nullptr;
+ raw_ptr<void> mUserdata1;
+ raw_ptr<void> mUserdata2;
// Note that the message is optional because we want to return nullptr when it wasn't set
// instead of a pointer to an empty string.
@@ -198,6 +216,29 @@
cmd.future = {futureIDInternal};
cmd.adapterObjectHandle = adapter->GetWireHandle();
cmd.options = options;
+ cmd.userdataCount = 1;
+
+ client->SerializeCommand(cmd);
+ return {futureIDInternal};
+}
+
+WGPUFuture Instance::RequestAdapter2(const WGPURequestAdapterOptions* options,
+ const WGPURequestAdapterCallbackInfo2& callbackInfo) {
+ Client* client = GetClient();
+ Adapter* adapter = client->Make<Adapter>(GetEventManagerHandle());
+ auto [futureIDInternal, tracked] =
+ GetEventManager().TrackEvent(std::make_unique<RequestAdapterEvent>(callbackInfo, adapter));
+ if (!tracked) {
+ return {futureIDInternal};
+ }
+
+ InstanceRequestAdapterCmd cmd;
+ cmd.instanceId = GetWireId();
+ cmd.eventManagerHandle = GetEventManagerHandle();
+ cmd.future = {futureIDInternal};
+ cmd.adapterObjectHandle = adapter->GetWireHandle();
+ cmd.options = options;
+ cmd.userdataCount = 2;
client->SerializeCommand(cmd);
return {futureIDInternal};
diff --git a/src/dawn/wire/client/Instance.h b/src/dawn/wire/client/Instance.h
index 466385c..c309047 100644
--- a/src/dawn/wire/client/Instance.h
+++ b/src/dawn/wire/client/Instance.h
@@ -49,6 +49,8 @@
void* userdata);
WGPUFuture RequestAdapterF(const WGPURequestAdapterOptions* options,
const WGPURequestAdapterCallbackInfo& callbackInfo);
+ WGPUFuture RequestAdapter2(const WGPURequestAdapterOptions* options,
+ const WGPURequestAdapterCallbackInfo2& callbackInfo);
void ProcessEvents();
WGPUWaitStatus WaitAny(size_t count, WGPUFutureWaitInfo* infos, uint64_t timeoutNS);
diff --git a/src/dawn/wire/server/Server.h b/src/dawn/wire/server/Server.h
index 2d29945..a64d59f 100644
--- a/src/dawn/wire/server/Server.h
+++ b/src/dawn/wire/server/Server.h
@@ -81,6 +81,7 @@
template <typename Return, typename Class, typename Userdata, typename... Args>
struct ExtractedTypes<Return (Class::*)(Userdata*, Args...)> {
using UntypedCallback = Return (*)(Args..., void*);
+ using UntypedCallback2 = Return (*)(Args..., void*, void*);
static Return Callback(Args... args, void* userdata) {
// Acquire the userdata, and cast it to UserdataT.
std::unique_ptr<Userdata> data(static_cast<Userdata*>(userdata));
@@ -92,15 +93,31 @@
// Forward the arguments and the typed userdata to the Server:: member function.
(server.get()->*F)(data.get(), std::forward<decltype(args)>(args)...);
}
+ static Return Callback2(Args... args, void* userdata, void*) {
+ // Acquire the userdata, and cast it to UserdataT.
+ std::unique_ptr<Userdata> data(static_cast<Userdata*>(userdata));
+ auto server = data->server.lock();
+ if (!server) {
+ // Do nothing if the server has already been destroyed.
+ return;
+ }
+ // Forward the arguments and the typed userdata to the Server:: member function.
+ (server.get()->*F)(data.get(), std::forward<decltype(args)>(args)...);
+ }
};
static constexpr typename ExtractedTypes<decltype(F)>::UntypedCallback Create() {
return ExtractedTypes<decltype(F)>::Callback;
}
+ static constexpr typename ExtractedTypes<decltype(F)>::UntypedCallback2 Create2() {
+ return ExtractedTypes<decltype(F)>::Callback2;
+ }
};
template <auto F>
constexpr auto ForwardToServer = ForwardToServerHelper<F>::Create();
+template <auto F>
+constexpr auto ForwardToServer2 = ForwardToServerHelper<F>::Create2();
struct MapUserdata : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
diff --git a/src/dawn/wire/server/ServerInstance.cpp b/src/dawn/wire/server/ServerInstance.cpp
index e47c09f..a204691 100644
--- a/src/dawn/wire/server/ServerInstance.cpp
+++ b/src/dawn/wire/server/ServerInstance.cpp
@@ -38,7 +38,8 @@
ObjectHandle eventManager,
WGPUFuture future,
ObjectHandle adapterHandle,
- const WGPURequestAdapterOptions* options) {
+ const WGPURequestAdapterOptions* options,
+ uint8_t userdataCount) {
Reserved<WGPUAdapter> adapter;
WIRE_TRY(Objects<WGPUAdapter>().Allocate(&adapter, adapterHandle, AllocationState::Reserved));
@@ -47,9 +48,16 @@
userdata->future = future;
userdata->adapterObjectId = adapter.id;
- mProcs.instanceRequestAdapter(instance->handle, options,
- ForwardToServer<&Server::OnRequestAdapterCallback>,
- userdata.release());
+ if (userdataCount == 1) {
+ mProcs.instanceRequestAdapter(instance->handle, options,
+ ForwardToServer<&Server::OnRequestAdapterCallback>,
+ userdata.release());
+ } else {
+ mProcs.instanceRequestAdapter2(
+ instance->handle, options,
+ {nullptr, WGPUCallbackMode_AllowSpontaneous,
+ ForwardToServer2<&Server::OnRequestAdapterCallback>, userdata.release(), nullptr});
+ }
return WireResult::Success;
}