Add helper functions to iterate over ChainedStructs
This CL adds two helpers for more ergonomic processing of
ChainedStructs.
1. FindInChain(): Iterates through the chain and automatically
casts the ChainedStruct into the appropriate child type before
returning.
2. ValidateSTypes(): Verifies that the chain only contains structs
with sTypes from a pre-defined set. This also allows the caller
to specify one-of constraints.
3. ValidateSingleSType(): Verifies that the chain contains a
single struct with a specific sType or is an empty chain. This
is a common case of |ValidateSTypes()| and is separated out as
a fast-path.
Change-Id: I938df0bf2a9b1800b1105fb7f80fbde20bef8ec8
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47680
Commit-Queue: Brian Ho <hob@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/docs/codegen.md b/docs/codegen.md
index 9cdf40b..d79c0e9 100644
--- a/docs/codegen.md
+++ b/docs/codegen.md
@@ -18,6 +18,7 @@
- validation helper functions for dawn_native
- the definition of dawn_native's proc table
- dawn_native's internal version of the webgpu.h types
+ - utilities for working with dawn_native's chained structs
- a lot of dawn_wire parts, see below
Internally `dawn.json` is a dictionary from the "canonical name" of things to their definition. The "canonical name" is a space-separated (mostly) lower-case version of the name that's parsed into a `Name` Python object. Then that name can be turned into various casings with `.CamelCase()` `.SNAKE_CASE()`, etc. When `dawn.json` things reference each other, it is always via these "canonical names".
diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py
index 84a2b99..858be52 100644
--- a/generator/dawn_json_generator.py
+++ b/generator/dawn_json_generator.py
@@ -765,6 +765,14 @@
renders.append(
FileRender('dawn_native/ProcTable.cpp',
'src/dawn_native/ProcTable.cpp', frontend_params))
+ renders.append(
+ FileRender('dawn_native/ChainUtils.h',
+ 'src/dawn_native/ChainUtils_autogen.h',
+ frontend_params))
+ renders.append(
+ FileRender('dawn_native/ChainUtils.cpp',
+ 'src/dawn_native/ChainUtils_autogen.cpp',
+ frontend_params))
if 'dawn_wire' in targets:
additional_params = compute_wire_params(api_params, wire_json)
diff --git a/generator/templates/dawn_native/ChainUtils.cpp b/generator/templates/dawn_native/ChainUtils.cpp
new file mode 100644
index 0000000..2a42db2
--- /dev/null
+++ b/generator/templates/dawn_native/ChainUtils.cpp
@@ -0,0 +1,61 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dawn_native/ChainUtils_autogen.h"
+
+#include <unordered_set>
+
+namespace dawn_native {
+
+{% for value in types["s type"].values %}
+ {% if value.valid %}
+ void FindInChain(const ChainedStruct* chain, const {{as_cppEnum(value.name)}}** out) {
+ for (; chain; chain = chain->nextInChain) {
+ if (chain->sType == wgpu::SType::{{as_cppEnum(value.name)}}) {
+ *out = static_cast<const {{as_cppEnum(value.name)}}*>(chain);
+ break;
+ }
+ }
+ }
+ {% endif %}
+{% endfor %}
+
+MaybeError ValidateSTypes(const ChainedStruct* chain,
+ std::vector<std::vector<wgpu::SType>> oneOfConstraints) {
+ std::unordered_set<wgpu::SType> allSTypes;
+ for (; chain; chain = chain->nextInChain) {
+ if (allSTypes.find(chain->sType) != allSTypes.end()) {
+ return DAWN_VALIDATION_ERROR("Chain cannot have duplicate sTypes");
+ }
+ allSTypes.insert(chain->sType);
+ }
+ for (const auto& oneOfConstraint : oneOfConstraints) {
+ bool satisfied = false;
+ for (wgpu::SType oneOfSType : oneOfConstraint) {
+ if (allSTypes.find(oneOfSType) != allSTypes.end()) {
+ if (satisfied) {
+ return DAWN_VALIDATION_ERROR("Unsupported sType combination");
+ }
+ satisfied = true;
+ allSTypes.erase(oneOfSType);
+ }
+ }
+ }
+ if (!allSTypes.empty()) {
+ return DAWN_VALIDATION_ERROR("Unsupported sType");
+ }
+ return {};
+}
+
+} // namespace dawn_native
diff --git a/generator/templates/dawn_native/ChainUtils.h b/generator/templates/dawn_native/ChainUtils.h
new file mode 100644
index 0000000..ce46591
--- /dev/null
+++ b/generator/templates/dawn_native/ChainUtils.h
@@ -0,0 +1,81 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DAWNNATIVE_CHAIN_UTILS_H_
+#define DAWNNATIVE_CHAIN_UTILS_H_
+
+#include "dawn_native/dawn_platform.h"
+#include "dawn_native/Error.h"
+
+namespace dawn_native {
+ {% for value in types["s type"].values %}
+ {% if value.valid %}
+ void FindInChain(const ChainedStruct* chain, const {{as_cppEnum(value.name)}}** out);
+ {% endif %}
+ {% endfor %}
+
+ // Verifies that |chain| only contains ChainedStructs of types enumerated in
+ // |oneOfConstraints| and contains no duplicate sTypes. Each vector in
+ // |oneOfConstraints| defines a set of sTypes that cannot coexist in the same chain.
+ // For example:
+ // ValidateSTypes(chain, { { ShaderModuleSPIRVDescriptor, ShaderModuleWGSLDescriptor } }))
+ // ValidateSTypes(chain, { { Extension1 }, { Extension2 } })
+ MaybeError ValidateSTypes(const ChainedStruct* chain,
+ std::vector<std::vector<wgpu::SType>> oneOfConstraints);
+
+ template <typename T>
+ MaybeError ValidateSingleSTypeInner(const ChainedStruct* chain, T sType) {
+ if (chain->sType != sType) {
+ return DAWN_VALIDATION_ERROR("Unsupported sType");
+ }
+ return {};
+ }
+
+ template <typename T, typename... Args>
+ MaybeError ValidateSingleSTypeInner(const ChainedStruct* chain, T sType, Args... sTypes) {
+ if (chain->sType == sType) {
+ return {};
+ }
+ return ValidateSingleSTypeInner(chain, sTypes...);
+ }
+
+ // Verifies that |chain| contains a single ChainedStruct of type |sType| or no ChainedStructs
+ // at all.
+ template <typename T>
+ MaybeError ValidateSingleSType(const ChainedStruct* chain, T sType) {
+ if (chain == nullptr) {
+ return {};
+ }
+ if (chain->nextInChain != nullptr) {
+ return DAWN_VALIDATION_ERROR("Chain can only contain a single chained struct");
+ }
+ return ValidateSingleSTypeInner(chain, sType);
+ }
+
+ // Verifies that |chain| contains a single ChainedStruct with a type enumerated in the
+ // parameter pack or no ChainedStructs at all.
+ template <typename T, typename... Args>
+ MaybeError ValidateSingleSType(const ChainedStruct* chain, T sType, Args... sTypes) {
+ if (chain == nullptr) {
+ return {};
+ }
+ if (chain->nextInChain != nullptr) {
+ return DAWN_VALIDATION_ERROR("Chain can only contain a single chained struct");
+ }
+ return ValidateSingleSTypeInner(chain, sType, sTypes...);
+ }
+
+} // namespace dawn_native
+
+#endif // DAWNNATIVE_CHAIN_UTILS_H_
diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn
index 803ee30..2a5a59b 100644
--- a/src/dawn_native/BUILD.gn
+++ b/src/dawn_native/BUILD.gn
@@ -91,6 +91,8 @@
dawn_json_generator("dawn_native_utils_gen") {
target = "dawn_native_utils"
outputs = [
+ "src/dawn_native/ChainUtils_autogen.h",
+ "src/dawn_native/ChainUtils_autogen.cpp",
"src/dawn_native/ProcTable.cpp",
"src/dawn_native/wgpu_structs_autogen.h",
"src/dawn_native/wgpu_structs_autogen.cpp",
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index c09bdfb..c32034d 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -16,6 +16,7 @@
#include "common/BitSetIterator.h"
#include "common/VertexFormatUtils.h"
+#include "dawn_native/ChainUtils_autogen.h"
#include "dawn_native/Commands.h"
#include "dawn_native/Device.h"
#include "dawn_native/ObjectContentHasher.h"
@@ -133,16 +134,13 @@
MaybeError ValidatePrimitiveState(const DeviceBase* device,
const PrimitiveState* descriptor) {
- const ChainedStruct* chained = descriptor->nextInChain;
- if (chained != nullptr) {
- if (chained->sType != wgpu::SType::PrimitiveDepthClampingState) {
- return DAWN_VALIDATION_ERROR("Unsupported sType");
- }
- if (!device->IsExtensionEnabled(Extension::DepthClamping)) {
- return DAWN_VALIDATION_ERROR("The depth clamping feature is not supported");
- }
+ DAWN_TRY(ValidateSingleSType(descriptor->nextInChain,
+ wgpu::SType::PrimitiveDepthClampingState));
+ const PrimitiveDepthClampingState* clampInfo = nullptr;
+ FindInChain(descriptor->nextInChain, &clampInfo);
+ if (clampInfo && !device->IsExtensionEnabled(Extension::DepthClamping)) {
+ return DAWN_VALIDATION_ERROR("The depth clamping feature is not supported");
}
-
DAWN_TRY(ValidatePrimitiveTopology(descriptor->topology));
DAWN_TRY(ValidateIndexFormat(descriptor->stripIndexFormat));
DAWN_TRY(ValidateFrontFace(descriptor->frontFace));
@@ -426,11 +424,10 @@
}
mPrimitive = descriptor->primitive;
- const ChainedStruct* chained = mPrimitive.nextInChain;
- if (chained != nullptr) {
- ASSERT(chained->sType == wgpu::SType::PrimitiveDepthClampingState);
- const auto* clampState = static_cast<const PrimitiveDepthClampingState*>(chained);
- mClampDepth = clampState->clampDepth;
+ const PrimitiveDepthClampingState* clampInfo = nullptr;
+ FindInChain(mPrimitive.nextInChain, &clampInfo);
+ if (clampInfo) {
+ mClampDepth = clampInfo->clampDepth;
}
mMultisample = descriptor->multisample;
diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp
index a9cceac..5e5762f 100644
--- a/src/dawn_native/ShaderModule.cpp
+++ b/src/dawn_native/ShaderModule.cpp
@@ -17,6 +17,7 @@
#include "common/HashUtils.h"
#include "common/VertexFormatUtils.h"
#include "dawn_native/BindGroupLayout.h"
+#include "dawn_native/ChainUtils_autogen.h"
#include "dawn_native/CompilationMessages.h"
#include "dawn_native/Device.h"
#include "dawn_native/ObjectContentHasher.h"
@@ -1069,65 +1070,56 @@
return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor");
}
// For now only a single SPIRV or WGSL subdescriptor is allowed.
- if (chainedDescriptor->nextInChain != nullptr) {
- return DAWN_VALIDATION_ERROR(
- "Shader module descriptor chained nextInChain must be nullptr");
- }
+ DAWN_TRY(ValidateSingleSType(chainedDescriptor,
+ wgpu::SType::ShaderModuleSPIRVDescriptor,
+ wgpu::SType::ShaderModuleWGSLDescriptor));
OwnedCompilationMessages* outMessages = parseResult->compilationMessages.get();
ScopedTintICEHandler scopedICEHandler(device);
- switch (chainedDescriptor->sType) {
- case wgpu::SType::ShaderModuleSPIRVDescriptor: {
- const auto* spirvDesc =
- static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
- std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
- if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
- tint::Program program;
- DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
- parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
- } else {
- if (device->IsValidationEnabled()) {
- DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
- }
- parseResult->spirv = std::move(spirv);
- }
- break;
- }
+ const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
+ FindInChain(chainedDescriptor, &spirvDesc);
+ const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
+ FindInChain(chainedDescriptor, &wgslDesc);
- case wgpu::SType::ShaderModuleWGSLDescriptor: {
- const auto* wgslDesc =
- static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
-
- auto tintSource = std::make_unique<TintSource>("", wgslDesc->source);
-
+ if (spirvDesc) {
+ std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+ if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
tint::Program program;
- DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages));
-
- if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
- parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
- parseResult->tintSource = std::move(tintSource);
- } else {
- tint::transform::Manager transformManager;
- transformManager.Add<tint::transform::EmitVertexPointSize>();
- transformManager.Add<tint::transform::Spirv>();
-
- tint::transform::DataMap transformInputs;
-
- DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
- transformInputs, nullptr, outMessages));
-
- std::vector<uint32_t> spirv;
- DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
+ DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
+ parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
+ } else {
+ if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
-
- parseResult->spirv = std::move(spirv);
}
- break;
+ parseResult->spirv = std::move(spirv);
}
- default:
- return DAWN_VALIDATION_ERROR("Unsupported sType");
+ } else if (wgslDesc) {
+ auto tintSource = std::make_unique<TintSource>("", wgslDesc->source);
+
+ tint::Program program;
+ DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages));
+
+ if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
+ parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
+ parseResult->tintSource = std::move(tintSource);
+ } else {
+ tint::transform::Manager transformManager;
+ transformManager.Add<tint::transform::EmitVertexPointSize>();
+ transformManager.Add<tint::transform::Spirv>();
+
+ tint::transform::DataMap transformInputs;
+
+ DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, &program,
+ transformInputs, nullptr, outMessages));
+
+ std::vector<uint32_t> spirv;
+ DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(&program));
+ DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
+
+ parseResult->spirv = std::move(spirv);
+ }
}
return {};
@@ -1216,23 +1208,18 @@
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
: CachedObject(device), mType(Type::Undefined) {
ASSERT(descriptor->nextInChain != nullptr);
- switch (descriptor->nextInChain->sType) {
- case wgpu::SType::ShaderModuleSPIRVDescriptor: {
- mType = Type::Spirv;
- const auto* spirvDesc =
- static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
- mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
- break;
- }
- case wgpu::SType::ShaderModuleWGSLDescriptor: {
- mType = Type::Wgsl;
- const auto* wgslDesc =
- static_cast<const ShaderModuleWGSLDescriptor*>(descriptor->nextInChain);
- mWgsl = std::string(wgslDesc->source);
- break;
- }
- default:
- UNREACHABLE();
+ const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
+ FindInChain(descriptor->nextInChain, &spirvDesc);
+ const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
+ FindInChain(descriptor->nextInChain, &wgslDesc);
+ ASSERT(spirvDesc || wgslDesc);
+
+ if (spirvDesc) {
+ mType = Type::Spirv;
+ mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
+ } else if (wgslDesc) {
+ mType = Type::Wgsl;
+ mWgsl = std::string(wgslDesc->source);
}
}
diff --git a/src/dawn_native/Surface.cpp b/src/dawn_native/Surface.cpp
index 4afe05e..9b317bc 100644
--- a/src/dawn_native/Surface.cpp
+++ b/src/dawn_native/Surface.cpp
@@ -15,6 +15,7 @@
#include "dawn_native/Surface.h"
#include "common/Platform.h"
+#include "dawn_native/ChainUtils_autogen.h"
#include "dawn_native/Instance.h"
#include "dawn_native/SwapChain.h"
@@ -34,75 +35,60 @@
MaybeError ValidateSurfaceDescriptor(const InstanceBase* instance,
const SurfaceDescriptor* descriptor) {
- // TODO(cwallez@chromium.org): Have some type of helper to iterate over all the chained
- // structures.
if (descriptor->nextInChain == nullptr) {
return DAWN_VALIDATION_ERROR("Surface cannot be created with just the base descriptor");
}
- const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
- if (chainedDescriptor->nextInChain != nullptr) {
- return DAWN_VALIDATION_ERROR("Cannot specify two windows for a single surface");
- }
+ DAWN_TRY(ValidateSingleSType(descriptor->nextInChain,
+ wgpu::SType::SurfaceDescriptorFromMetalLayer,
+ wgpu::SType::SurfaceDescriptorFromWindowsHWND,
+ wgpu::SType::SurfaceDescriptorFromXlib));
- switch (chainedDescriptor->sType) {
#if defined(DAWN_ENABLE_BACKEND_METAL)
- case wgpu::SType::SurfaceDescriptorFromMetalLayer: {
- const SurfaceDescriptorFromMetalLayer* metalDesc =
- static_cast<const SurfaceDescriptorFromMetalLayer*>(chainedDescriptor);
-
- // Check that the layer is a CAMetalLayer (or a derived class).
- if (!InheritsFromCAMetalLayer(metalDesc->layer)) {
- return DAWN_VALIDATION_ERROR("layer must be a CAMetalLayer");
- }
- break;
- }
+ const SurfaceDescriptorFromMetalLayer* metalDesc = nullptr;
+ FindInChain(descriptor->nextInChain, &metalDesc);
+ if (!metalDesc) {
+ return DAWN_VALIDATION_ERROR("Unsupported sType");
+ }
+ // Check that the layer is a CAMetalLayer (or a derived class).
+ if (!InheritsFromCAMetalLayer(metalDesc->layer)) {
+ return DAWN_VALIDATION_ERROR("layer must be a CAMetalLayer");
+ }
#endif // defined(DAWN_ENABLE_BACKEND_METAL)
#if defined(DAWN_PLATFORM_WINDOWS)
- case wgpu::SType::SurfaceDescriptorFromWindowsHWND: {
- const SurfaceDescriptorFromWindowsHWND* hwndDesc =
- static_cast<const SurfaceDescriptorFromWindowsHWND*>(chainedDescriptor);
-
- // It is not possible to validate an HINSTANCE.
-
- // Validate the hwnd using the windows.h IsWindow function.
- if (IsWindow(static_cast<HWND>(hwndDesc->hwnd)) == 0) {
- return DAWN_VALIDATION_ERROR("Invalid HWND");
- }
- break;
- }
+ const SurfaceDescriptorFromWindowsHWND* hwndDesc = nullptr;
+ FindInChain(descriptor->nextInChain, &hwndDesc);
+ if (!hwndDesc) {
+ return DAWN_VALIDATION_ERROR("Unsupported sType");
+ }
+ // Validate the hwnd using the windows.h IsWindow function.
+ if (IsWindow(static_cast<HWND>(hwndDesc->hwnd)) == 0) {
+ return DAWN_VALIDATION_ERROR("Invalid HWND");
+ }
#endif // defined(DAWN_PLATFORM_WINDOWS)
#if defined(DAWN_USE_X11)
- case wgpu::SType::SurfaceDescriptorFromXlib: {
- const SurfaceDescriptorFromXlib* xDesc =
- static_cast<const SurfaceDescriptorFromXlib*>(chainedDescriptor);
-
- // It is not possible to validate an X Display.
-
- // Check the validity of the window by calling a getter function on the window that
- // returns a status code. If the window is bad the call return a status of zero. We
- // need to set a temporary X11 error handler while doing this because the default
- // X11 error handler exits the program on any error.
- XErrorHandler oldErrorHandler =
- XSetErrorHandler([](Display*, XErrorEvent*) { return 0; });
- XWindowAttributes attributes;
- int status = XGetWindowAttributes(reinterpret_cast<Display*>(xDesc->display),
- xDesc->window, &attributes);
- XSetErrorHandler(oldErrorHandler);
-
- if (status == 0) {
- return DAWN_VALIDATION_ERROR("Invalid X Window");
- }
- break;
- }
-#endif // defined(DAWN_USE_X11)
-
- case wgpu::SType::SurfaceDescriptorFromCanvasHTMLSelector:
- default:
- return DAWN_VALIDATION_ERROR("Unsupported sType");
+ const SurfaceDescriptorFromXlib* xDesc = nullptr;
+ FindInChain(descriptor->nextInChain, &xDesc);
+ if (!xDesc) {
+ return DAWN_VALIDATION_ERROR("Unsupported sType");
}
+ // Check the validity of the window by calling a getter function on the window that
+ // returns a status code. If the window is bad the call return a status of zero. We
+ // need to set a temporary X11 error handler while doing this because the default
+ // X11 error handler exits the program on any error.
+ XErrorHandler oldErrorHandler =
+ XSetErrorHandler([](Display*, XErrorEvent*) { return 0; });
+ XWindowAttributes attributes;
+ int status = XGetWindowAttributes(reinterpret_cast<Display*>(xDesc->display),
+ xDesc->window, &attributes);
+ XSetErrorHandler(oldErrorHandler);
+
+ if (status == 0) {
+ return DAWN_VALIDATION_ERROR("Invalid X Window");
+ }
+#endif // defined(DAWN_USE_X11)
return {};
}
@@ -110,37 +96,24 @@
Surface::Surface(InstanceBase* instance, const SurfaceDescriptor* descriptor)
: mInstance(instance) {
ASSERT(descriptor->nextInChain != nullptr);
- const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
-
- switch (chainedDescriptor->sType) {
- case wgpu::SType::SurfaceDescriptorFromMetalLayer: {
- const SurfaceDescriptorFromMetalLayer* metalDesc =
- static_cast<const SurfaceDescriptorFromMetalLayer*>(chainedDescriptor);
- mType = Type::MetalLayer;
- mMetalLayer = metalDesc->layer;
- break;
- }
-
- case wgpu::SType::SurfaceDescriptorFromWindowsHWND: {
- const SurfaceDescriptorFromWindowsHWND* hwndDesc =
- static_cast<const SurfaceDescriptorFromWindowsHWND*>(chainedDescriptor);
- mType = Type::WindowsHWND;
- mHInstance = hwndDesc->hinstance;
- mHWND = hwndDesc->hwnd;
- break;
- }
-
- case wgpu::SType::SurfaceDescriptorFromXlib: {
- const SurfaceDescriptorFromXlib* xDesc =
- static_cast<const SurfaceDescriptorFromXlib*>(chainedDescriptor);
- mType = Type::Xlib;
- mXDisplay = xDesc->display;
- mXWindow = xDesc->window;
- break;
- }
-
- default:
- UNREACHABLE();
+ const SurfaceDescriptorFromMetalLayer* metalDesc = nullptr;
+ const SurfaceDescriptorFromWindowsHWND* hwndDesc = nullptr;
+ const SurfaceDescriptorFromXlib* xDesc = nullptr;
+ FindInChain(descriptor->nextInChain, &metalDesc);
+ FindInChain(descriptor->nextInChain, &hwndDesc);
+ FindInChain(descriptor->nextInChain, &xDesc);
+ ASSERT(metalDesc || hwndDesc || xDesc);
+ if (metalDesc) {
+ mType = Type::MetalLayer;
+ mMetalLayer = metalDesc->layer;
+ } else if (hwndDesc) {
+ mType = Type::WindowsHWND;
+ mHInstance = hwndDesc->hinstance;
+ mHWND = hwndDesc->hwnd;
+ } else if (xDesc) {
+ mType = Type::Xlib;
+ mXDisplay = xDesc->display;
+ mXWindow = xDesc->window;
}
}
diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn
index faf4b22..284cc41c 100644
--- a/src/tests/BUILD.gn
+++ b/src/tests/BUILD.gn
@@ -156,6 +156,7 @@
"unittests/BitSetIteratorTests.cpp",
"unittests/BuddyAllocatorTests.cpp",
"unittests/BuddyMemoryAllocatorTests.cpp",
+ "unittests/ChainUtilsTests.cpp",
"unittests/CommandAllocatorTests.cpp",
"unittests/EnumClassBitmasksTests.cpp",
"unittests/EnumMaskIteratorTests.cpp",
diff --git a/src/tests/unittests/ChainUtilsTests.cpp b/src/tests/unittests/ChainUtilsTests.cpp
new file mode 100644
index 0000000..2d43729
--- /dev/null
+++ b/src/tests/unittests/ChainUtilsTests.cpp
@@ -0,0 +1,181 @@
+// Copyright 2021 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "dawn_native/ChainUtils_autogen.h"
+#include "dawn_native/dawn_platform.h"
+
+// Checks that we cannot find any structs in an empty chain
+TEST(ChainUtilsTests, FindEmptyChain) {
+ const dawn_native::PrimitiveDepthClampingState* info = nullptr;
+ dawn_native::FindInChain(nullptr, &info);
+
+ ASSERT_EQ(nullptr, info);
+}
+
+// Checks that searching a chain for a present struct returns that struct
+TEST(ChainUtilsTests, FindPresentInChain) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ chain1.nextInChain = &chain2;
+ const dawn_native::PrimitiveDepthClampingState* info1 = nullptr;
+ const dawn_native::ShaderModuleSPIRVDescriptor* info2 = nullptr;
+ dawn_native::FindInChain(&chain1, &info1);
+ dawn_native::FindInChain(&chain1, &info2);
+
+ ASSERT_NE(nullptr, info1);
+ ASSERT_NE(nullptr, info2);
+}
+
+// Checks that searching a chain for a struct that doesn't exist returns a nullptr
+TEST(ChainUtilsTests, FindMissingInChain) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ chain1.nextInChain = &chain2;
+ const dawn_native::SurfaceDescriptorFromMetalLayer* info = nullptr;
+ dawn_native::FindInChain(&chain1, &info);
+
+ ASSERT_EQ(nullptr, info);
+}
+
+// Checks that validation rejects chains with duplicate STypes
+TEST(ChainUtilsTests, ValidateDuplicateSTypes) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ dawn_native::PrimitiveDepthClampingState chain3;
+ chain1.nextInChain = &chain2;
+ chain2.nextInChain = &chain3;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {});
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+}
+
+// Checks that validation rejects chains that contain unspecified STypes
+TEST(ChainUtilsTests, ValidateUnspecifiedSTypes) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ dawn_native::ShaderModuleWGSLDescriptor chain3;
+ chain1.nextInChain = &chain2;
+ chain2.nextInChain = &chain3;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {
+ {wgpu::SType::PrimitiveDepthClampingState},
+ {wgpu::SType::ShaderModuleSPIRVDescriptor},
+ });
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+}
+
+// Checks that validation rejects chains that contain multiple STypes from the same oneof
+// constraint.
+TEST(ChainUtilsTests, ValidateOneOfFailure) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ dawn_native::ShaderModuleWGSLDescriptor chain3;
+ chain1.nextInChain = &chain2;
+ chain2.nextInChain = &chain3;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1,
+ {{wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor}});
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+}
+
+// Checks that validation accepts chains that match the constraints.
+TEST(ChainUtilsTests, ValidateSuccess) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ chain1.nextInChain = &chain2;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSTypes(&chain1, {
+ {wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor},
+ {wgpu::SType::PrimitiveDepthClampingState},
+ {wgpu::SType::SurfaceDescriptorFromMetalLayer},
+ });
+ ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that validation always passes on empty chains.
+TEST(ChainUtilsTests, ValidateEmptyChain) {
+ dawn_native::MaybeError result = dawn_native::ValidateSTypes(nullptr, {
+ {wgpu::SType::ShaderModuleSPIRVDescriptor},
+ {wgpu::SType::PrimitiveDepthClampingState},
+ });
+ ASSERT_TRUE(result.IsSuccess());
+
+ result = dawn_native::ValidateSTypes(nullptr, {});
+ ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that singleton validation always passes on empty chains.
+TEST(ChainUtilsTests, ValidateSingleEmptyChain) {
+ dawn_native::MaybeError result = dawn_native::ValidateSingleSType(nullptr,
+ wgpu::SType::ShaderModuleSPIRVDescriptor);
+ ASSERT_TRUE(result.IsSuccess());
+
+ result = dawn_native::ValidateSingleSType(nullptr,
+ wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::PrimitiveDepthClampingState);
+ ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that singleton validation always fails on chains with multiple children.
+TEST(ChainUtilsTests, ValidateSingleMultiChain) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+ dawn_native::ShaderModuleSPIRVDescriptor chain2;
+ chain1.nextInChain = &chain2;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::PrimitiveDepthClampingState);
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+
+ result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::PrimitiveDepthClampingState, wgpu::SType::ShaderModuleSPIRVDescriptor);
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+}
+
+// Checks that singleton validation passes when the oneof constraint is met.
+TEST(ChainUtilsTests, ValidateSingleSatisfied) {
+ dawn_native::ShaderModuleWGSLDescriptor chain1;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::ShaderModuleWGSLDescriptor);
+ ASSERT_TRUE(result.IsSuccess());
+
+ result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor);
+ ASSERT_TRUE(result.IsSuccess());
+
+ result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::ShaderModuleWGSLDescriptor, wgpu::SType::ShaderModuleSPIRVDescriptor);
+ ASSERT_TRUE(result.IsSuccess());
+}
+
+// Checks that singleton validation passes when the oneof constraint is not met.
+TEST(ChainUtilsTests, ValidateSingleUnsatisfied) {
+ dawn_native::PrimitiveDepthClampingState chain1;
+
+ dawn_native::MaybeError result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::ShaderModuleWGSLDescriptor);
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+
+ result = dawn_native::ValidateSingleSType(&chain1,
+ wgpu::SType::ShaderModuleSPIRVDescriptor, wgpu::SType::ShaderModuleWGSLDescriptor);
+ ASSERT_TRUE(result.IsError());
+ result.AcquireError();
+}