WireCmd: require opt-in to treat ID 0 as nullptr instead of error.
In preparation for the descriptorization of BindGroup, support was added
to treat wire ID 0 as nullptr for a bunch of objects. Now that we have a
fuzzer for the wire+frontend, we need to validate when we have a 0 id.
Either the wire needs to reject the ID or the frontend needs to validate
against nullptrs. Since only few entrypoints will have a use for
nullptrs (bind groups, render pass resolve textures), we require an
opt-in in the JSON file for a structure member or an argument to be
optional.
This disables the tests related to ID 0 = nullptr, because we don't yet
have optional argument/members in dawn.json.
BUG=chromium:905273
BUG=chromium:906418
BUG=chromium:908678
Change-Id: If9a3c4857db43ca26a90abff2437e1cebb0ab79b
Reviewed-on: https://dawn-review.googlesource.com/c/2704
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Stephen White <senorblanco@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
diff --git a/generator/main.py b/generator/main.py
index 9f27d58..e58daac 100644
--- a/generator/main.py
+++ b/generator/main.py
@@ -79,11 +79,12 @@
Type.__init__(self, name, record)
class MethodArgument:
- def __init__(self, name, typ, annotation):
+ def __init__(self, name, typ, annotation, optional):
self.name = name
self.type = typ
self.annotation = annotation
self.length = None
+ self.optional = optional
Method = namedtuple('Method', ['name', 'return_type', 'arguments'])
class ObjectType(Type):
@@ -94,11 +95,12 @@
self.built_type = None
class StructureMember:
- def __init__(self, name, typ, annotation):
+ def __init__(self, name, typ, annotation, optional):
self.name = name
self.type = typ
self.annotation = annotation
self.length = None
+ self.optional = optional
class StructureType(Type):
def __init__(self, name, record):
@@ -120,7 +122,8 @@
arguments = []
arguments_by_name = {}
for a in record.get('args', []):
- arg = MethodArgument(Name(a['name']), types[a['type']], a.get('annotation', 'value'))
+ arg = MethodArgument(Name(a['name']), types[a['type']],
+ a.get('annotation', 'value'), a.get('optional', False))
arguments.append(arg)
arguments_by_name[arg.name.canonical_case()] = arg
@@ -153,7 +156,8 @@
def link_structure(struct, types):
def make_member(m):
- return StructureMember(Name(m['name']), types[m['type']], m.get('annotation', 'value'))
+ return StructureMember(Name(m['name']), types[m['type']],
+ m.get('annotation', 'value'), m.get('optional', False))
members = []
members_by_name = {}
@@ -410,9 +414,9 @@
if typ.is_builder:
methods.append(Method(Name('set error callback'), types['void'], [
- MethodArgument(Name('callback'), types['builder error callback'], 'value'),
- MethodArgument(Name('userdata1'), types['callback userdata'], 'value'),
- MethodArgument(Name('userdata2'), types['callback userdata'], 'value'),
+ MethodArgument(Name('callback'), types['builder error callback'], 'value', False),
+ MethodArgument(Name('userdata1'), types['callback userdata'], 'value', False),
+ MethodArgument(Name('userdata2'), types['callback userdata'], 'value', False),
]))
return methods
diff --git a/generator/templates/dawn_wire/WireClient.cpp b/generator/templates/dawn_wire/WireClient.cpp
index 6641a14..117cfad 100644
--- a/generator/templates/dawn_wire/WireClient.cpp
+++ b/generator/templates/dawn_wire/WireClient.cpp
@@ -214,11 +214,14 @@
// Implementation of the ObjectIdProvider interface
{% for type in by_category["object"] %}
- ObjectId GetId({{as_cType(type.name)}} object) const override {
+ ObjectId GetId({{as_cType(type.name)}} object) const final {
+ return reinterpret_cast<{{as_wireType(type)}}>(object)->id;
+ }
+ ObjectId GetOptionalId({{as_cType(type.name)}} object) const final {
if (object == nullptr) {
return 0;
}
- return reinterpret_cast<{{as_wireType(type)}}>(object)->id;
+ return GetId(object);
}
{% endfor %}
diff --git a/generator/templates/dawn_wire/WireCmd.cpp b/generator/templates/dawn_wire/WireCmd.cpp
index a72c9f4..ceb828f 100644
--- a/generator/templates/dawn_wire/WireCmd.cpp
+++ b/generator/templates/dawn_wire/WireCmd.cpp
@@ -48,7 +48,8 @@
//* Outputs the serialization code to put `in` in `out`
{% macro serialize_member(member, in, out) %}
{%- if member.type.category == "object" -%}
- {{out}} = provider.GetId({{in}});
+ {% set Optional = "Optional" if member.optional else "" %}
+ {{out}} = provider.Get{{Optional}}Id({{in}});
{% elif member.type.category == "structure"%}
{{as_cType(member.type.name)}}Serialize({{in}}, &{{out}}, buffer, provider);
{%- else -%}
@@ -59,7 +60,8 @@
//* Outputs the deserialization code to put `in` in `out`
{% macro deserialize_member(member, in, out) %}
{%- if member.type.category == "object" -%}
- DESERIALIZE_TRY(resolver.GetFromId({{in}}, &{{out}}));
+ {% set Optional = "Optional" if member.optional else "" %}
+ DESERIALIZE_TRY(resolver.Get{{Optional}}FromId({{in}}, &{{out}}));
{% elif member.type.category == "structure"%}
DESERIALIZE_TRY({{as_cType(member.type.name)}}Deserialize(&{{out}}, &{{in}}, buffer, size, allocator, resolver));
{%- else -%}
diff --git a/generator/templates/dawn_wire/WireCmd.h b/generator/templates/dawn_wire/WireCmd.h
index 9e451ef..2d53008 100644
--- a/generator/templates/dawn_wire/WireCmd.h
+++ b/generator/templates/dawn_wire/WireCmd.h
@@ -40,6 +40,7 @@
public:
{% for type in by_category["object"] %}
virtual DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0;
+ virtual DeserializeResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0;
{% endfor %}
};
@@ -48,6 +49,7 @@
public:
{% for type in by_category["object"] %}
virtual ObjectId GetId({{as_cType(type.name)}} object) const = 0;
+ virtual ObjectId GetOptionalId({{as_cType(type.name)}} object) const = 0;
{% endfor %}
};
diff --git a/generator/templates/dawn_wire/WireServer.cpp b/generator/templates/dawn_wire/WireServer.cpp
index 945f565..ebc3417 100644
--- a/generator/templates/dawn_wire/WireServer.cpp
+++ b/generator/templates/dawn_wire/WireServer.cpp
@@ -357,12 +357,7 @@
// Implementation of the ObjectIdResolver interface
{% for type in by_category["object"] %}
- DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const override {
- if (id == 0) {
- *out = nullptr;
- return DeserializeResult::Success;
- }
-
+ DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const final {
auto data = mKnown{{type.name.CamelCase()}}.Get(id);
if (data == nullptr) {
return DeserializeResult::FatalError;
@@ -375,6 +370,15 @@
return DeserializeResult::ErrorObject;
}
}
+
+ DeserializeResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const final {
+ if (id == 0) {
+ *out = nullptr;
+ return DeserializeResult::Success;
+ }
+
+ return GetFromId(id, out);
+ }
{% endfor %}
//* The list of known IDs for each object type.
diff --git a/src/tests/unittests/WireTests.cpp b/src/tests/unittests/WireTests.cpp
index 731ee96..4670cab 100644
--- a/src/tests/unittests/WireTests.cpp
+++ b/src/tests/unittests/WireTests.cpp
@@ -329,32 +329,27 @@
// Test that the wire is able to send objects as value arguments
TEST_F(WireTests, ObjectAsValueArgument) {
- // Create pipeline
- dawnComputePipelineDescriptor pipelineDesc;
- pipelineDesc.nextInChain = nullptr;
- pipelineDesc.layout = nullptr;
- pipelineDesc.entryPoint = "main";
- pipelineDesc.module = nullptr;
- dawnComputePipeline pipeline = dawnDeviceCreateComputePipeline(device, &pipelineDesc);
+ // Create a RenderPassDescriptor
+ dawnRenderPassDescriptorBuilder renderPassBuilder = dawnDeviceCreateRenderPassDescriptorBuilder(device);
+ dawnRenderPassDescriptor renderPass = dawnRenderPassDescriptorBuilderGetResult(renderPassBuilder);
- dawnComputePipeline apiPipeline = api.GetNewComputePipeline();
- EXPECT_CALL(api, DeviceCreateComputePipeline(apiDevice, _))
- .WillOnce(Return(apiPipeline));
+ dawnRenderPassDescriptorBuilder apiRenderPassBuilder = api.GetNewRenderPassDescriptorBuilder();
+ EXPECT_CALL(api, DeviceCreateRenderPassDescriptorBuilder(apiDevice))
+ .WillOnce(Return(apiRenderPassBuilder));
+ dawnRenderPassDescriptor apiRenderPass = api.GetNewRenderPassDescriptor();
+ EXPECT_CALL(api, RenderPassDescriptorBuilderGetResult(apiRenderPassBuilder))
+ .WillOnce(Return(apiRenderPass));
- // Create command buffer builder, setting pipeline
+ // Create command buffer builder, setting render pass descriptor
dawnCommandBufferBuilder cmdBufBuilder = dawnDeviceCreateCommandBufferBuilder(device);
- dawnComputePassEncoder pass = dawnCommandBufferBuilderBeginComputePass(cmdBufBuilder);
- dawnComputePassEncoderSetComputePipeline(pass, pipeline);
+ dawnCommandBufferBuilderBeginRenderPass(cmdBufBuilder, renderPass);
dawnCommandBufferBuilder apiCmdBufBuilder = api.GetNewCommandBufferBuilder();
EXPECT_CALL(api, DeviceCreateCommandBufferBuilder(apiDevice))
.WillOnce(Return(apiCmdBufBuilder));
- dawnComputePassEncoder apiPass = api.GetNewComputePassEncoder();
- EXPECT_CALL(api, CommandBufferBuilderBeginComputePass(apiCmdBufBuilder))
- .WillOnce(Return(apiPass));
-
- EXPECT_CALL(api, ComputePassEncoderSetComputePipeline(apiPass, apiPipeline));
+ EXPECT_CALL(api, CommandBufferBuilderBeginRenderPass(apiCmdBufBuilder, apiRenderPass))
+ .Times(1);
FlushClient();
}
@@ -486,7 +481,7 @@
}
// Test passing nullptr instead of objects - object as value version
-TEST_F(WireTests, NullptrAsValue) {
+TEST_F(WireTests, DISABLED_NullptrAsValue) {
dawnCommandBufferBuilder builder = dawnDeviceCreateCommandBufferBuilder(device);
dawnComputePassEncoder pass = dawnCommandBufferBuilderBeginComputePass(builder);
dawnComputePassEncoderSetComputePipeline(pass, nullptr);
@@ -506,7 +501,7 @@
}
// Test passing nullptr instead of objects - array of objects version
-TEST_F(WireTests, NullptrInArray) {
+TEST_F(WireTests, DISABLED_NullptrInArray) {
dawnBindGroupLayout nullBGL = nullptr;
dawnPipelineLayoutDescriptor descriptor;