Make codegen not cast between fnptrs.
When generating the proc tables for dawn_native and dawn_wire (for the
client), we were casting between function pointers with the C types and
function pointers with the internal types. This is UB and was caught by
UBSan.
Replace casts between function pointers by casts between types inside
the functions themselves.
BUG=chromium:906349
Change-Id: Icd8f6eedfa729e767ae3bacb2d6951f5ad5c4c82
Reviewed-on: https://dawn-review.googlesource.com/c/2400
Reviewed-by: Stephen White <senorblanco@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/generator/templates/dawn_native/ProcTable.cpp b/generator/templates/dawn_native/ProcTable.cpp
index e3f9d35..21c2926 100644
--- a/generator/templates/dawn_native/ProcTable.cpp
+++ b/generator/templates/dawn_native/ProcTable.cpp
@@ -81,16 +81,30 @@
}
//* Entry point with validation
- {{as_frontendType(method.return_type)}} Validating{{suffix}}(
- {{-as_frontendType(type)}} self
+ {{as_cType(method.return_type.name)}} Validating{{suffix}}(
+ {{-as_cType(type.name)}} cSelf
{%- for arg in method.arguments -%}
- , {{as_annotated_frontendType(arg)}}
+ , {{as_annotated_cType(arg)}}
{%- endfor -%}
) {
+ //* Perform conversion between C types and frontend types
+ auto self = reinterpret_cast<{{as_frontendType(type)}}>(cSelf);
+
+ {% for arg in method.arguments %}
+ {% set varName = as_varName(arg.name) %}
+ {% if arg.type.category in ["enum", "bitmask"] %}
+ auto {{varName}}_ = static_cast<{{as_frontendType(arg.type)}}>({{varName}});
+ {% elif arg.annotation != "value" or arg.type.category == "object" %}
+ auto {{varName}}_ = reinterpret_cast<{{decorate("", as_frontendType(arg.type), arg)}}>({{varName}});
+ {% else %}
+ auto {{varName}}_ = {{as_varName(arg.name)}};
+ {% endif %}
+ {%- endfor-%}
+
//* Do the autogenerated checks
bool valid = ValidateBase{{suffix}}(self
{%- for arg in method.arguments -%}
- , {{as_varName(arg.name)}}
+ , {{as_varName(arg.name)}}_
{%- endfor -%}
);
@@ -100,7 +114,7 @@
if (valid) {
MaybeError error = self->Validate{{method.name.CamelCase()}}(
{%- for arg in method.arguments -%}
- {% if not loop.first %}, {% endif %}{{as_varName(arg.name)}}
+ {% if not loop.first %}, {% endif %}{{as_varName(arg.name)}}_
{%- endfor -%}
);
//* Builders want to handle error themselves, unpack the error and make
@@ -141,17 +155,11 @@
self->{{method.name.CamelCase()}}(
{%- for arg in method.arguments -%}
{%- if not loop.first %}, {% endif -%}
- {%- if arg.type.category in ["enum", "bitmask"] -%}
- static_cast<dawn::{{as_cppType(arg.type.name)}}>({{as_varName(arg.name)}})
- {%- elif arg.type.category == "structure" and arg.annotation != "value" -%}
- reinterpret_cast<const {{as_cppType(arg.type.name)}}*>({{as_varName(arg.name)}})
- {%- else -%}
- {{as_varName(arg.name)}}
- {%- endif -%}
+ {{as_varName(arg.name)}}_
{%- endfor -%}
);
{% if method.return_type.name.canonical_case() != "void" %}
- return reinterpret_cast<{{as_frontendType(method.return_type)}}>(result);
+ return reinterpret_cast<{{as_cType(method.return_type.name)}}>(result);
{% endif %}
}
{% endfor %}
@@ -162,7 +170,7 @@
dawnProcTable table;
{% for type in by_category["object"] %}
{% for method in native_methods(type) %}
- table.{{as_varName(type.name, method.name)}} = reinterpret_cast<{{as_cProc(type.name, method.name)}}>(Validating{{as_MethodSuffix(type.name, method.name)}});
+ table.{{as_varName(type.name, method.name)}} = Validating{{as_MethodSuffix(type.name, method.name)}};
{% endfor %}
{% endfor %}
return table;
diff --git a/generator/templates/dawn_wire/WireClient.cpp b/generator/templates/dawn_wire/WireClient.cpp
index 4dd37eb..6641a14 100644
--- a/generator/templates/dawn_wire/WireClient.cpp
+++ b/generator/templates/dawn_wire/WireClient.cpp
@@ -238,12 +238,13 @@
//* Implementation of the client API functions.
{% for type in by_category["object"] %}
{% set Type = type.name.CamelCase() %}
+ {% set cType = as_cType(type.name) %}
{% for method in type.methods %}
{% set Suffix = as_MethodSuffix(type.name, method.name) %}
- {{as_wireType(method.return_type)}} Client{{Suffix}}(
- {{-as_cType(type.name)}} cSelf
+ {{as_cType(method.return_type.name)}} Client{{Suffix}}(
+ {{-cType}} cSelf
{%- for arg in method.arguments -%}
, {{as_annotated_cType(arg)}}
{%- endfor -%}
@@ -282,16 +283,17 @@
cmd.Serialize(allocatedBuffer, *device);
{% if method.return_type.category == "object" %}
- return allocation->object.get();
+ return reinterpret_cast<{{as_cType(method.return_type.name)}}>(allocation->object.get());
{% endif %}
}
{% endfor %}
{% if type.is_builder %}
- void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}({{Type}}* self,
+ void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}({{cType}} cSelf,
dawnBuilderErrorCallback callback,
dawnCallbackUserdata userdata1,
dawnCallbackUserdata userdata2) {
+ {{Type}}* self = reinterpret_cast<{{Type}}*>(cSelf);
self->builderCallback.callback = callback;
self->builderCallback.userdata1 = userdata1;
self->builderCallback.userdata2 = userdata2;
@@ -300,7 +302,8 @@
{% if not type.name.canonical_case() == "device" %}
//* When an object's refcount reaches 0, notify the server side of it and delete it.
- void Client{{as_MethodSuffix(type.name, Name("release"))}}({{Type}}* obj) {
+ void Client{{as_MethodSuffix(type.name, Name("release"))}}({{cType}} cObj) {
+ {{Type}}* obj = reinterpret_cast<{{Type}}*>(cObj);
obj->refcount --;
if (obj->refcount > 0) {
@@ -318,13 +321,16 @@
obj->device->{{type.name.camelCase()}}.Free(obj);
}
- void Client{{as_MethodSuffix(type.name, Name("reference"))}}({{Type}}* obj) {
+ void Client{{as_MethodSuffix(type.name, Name("reference"))}}({{cType}} cObj) {
+ {{Type}}* obj = reinterpret_cast<{{Type}}*>(cObj);
obj->refcount ++;
}
{% endif %}
{% endfor %}
- void ClientBufferMapReadAsync(Buffer* buffer, uint32_t start, uint32_t size, dawnBufferMapReadCallback callback, dawnCallbackUserdata userdata) {
+ void ClientBufferMapReadAsync(dawnBuffer cBuffer, uint32_t start, uint32_t size, dawnBufferMapReadCallback callback, dawnCallbackUserdata userdata) {
+ Buffer* buffer = reinterpret_cast<Buffer*>(cBuffer);
+
uint32_t serial = buffer->requestSerial++;
ASSERT(buffer->requests.find(serial) == buffer->requests.end());
@@ -346,7 +352,9 @@
*allocCmd = cmd;
}
- void ClientBufferMapWriteAsync(Buffer* buffer, uint32_t start, uint32_t size, dawnBufferMapWriteCallback callback, dawnCallbackUserdata userdata) {
+ void ClientBufferMapWriteAsync(dawnBuffer cBuffer, uint32_t start, uint32_t size, dawnBufferMapWriteCallback callback, dawnCallbackUserdata userdata) {
+ Buffer* buffer = reinterpret_cast<Buffer*>(cBuffer);
+
uint32_t serial = buffer->requestSerial++;
ASSERT(buffer->requests.find(serial) == buffer->requests.end());
@@ -401,13 +409,14 @@
ClientBufferUnmap(cBuffer);
}
- void ClientDeviceReference(Device*) {
+ void ClientDeviceReference(dawnDevice) {
}
- void ClientDeviceRelease(Device*) {
+ void ClientDeviceRelease(dawnDevice) {
}
- void ClientDeviceSetErrorCallback(Device* self, dawnDeviceErrorCallback callback, dawnCallbackUserdata userdata) {
+ void ClientDeviceSetErrorCallback(dawnDevice cSelf, dawnDeviceErrorCallback callback, dawnCallbackUserdata userdata) {
+ Device* self = reinterpret_cast<Device*>(cSelf);
self->errorCallback = callback;
self->errorUserdata = userdata;
}
@@ -425,9 +434,9 @@
{% for method in native_methods(type) %}
{% set suffix = as_MethodSuffix(type.name, method.name) %}
{% if suffix in proxied_commands %}
- table.{{as_varName(type.name, method.name)}} = reinterpret_cast<{{as_cProc(type.name, method.name)}}>(ProxyClient{{suffix}});
+ table.{{as_varName(type.name, method.name)}} = ProxyClient{{suffix}};
{% else %}
- table.{{as_varName(type.name, method.name)}} = reinterpret_cast<{{as_cProc(type.name, method.name)}}>(Client{{suffix}});
+ table.{{as_varName(type.name, method.name)}} = Client{{suffix}};
{% endif %}
{% endfor %}
{% endfor %}