Updates Chain utilities with wrapper and for Out structs as well.
- Adds Unpacked<T> template as wrapping util.
- Moves existing validation helpers into the class.
- Moves bitset into wrapping class so that it can be computed on
construction.
- Adds ValidateSubset helper to eventually replace ValidateSType
helpers.
- Fixes existing usages and updates tests.
Bug: dawn:1955
Change-Id: I72ad96ceab83a887a4a22ee1d12b8885fea36023
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/163600
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
diff --git a/generator/templates/dawn/native/ChainUtils.cpp b/generator/templates/dawn/native/ChainUtils.cpp
index 6966b96..9a36faf 100644
--- a/generator/templates/dawn/native/ChainUtils.cpp
+++ b/generator/templates/dawn/native/ChainUtils.cpp
@@ -29,7 +29,7 @@
{% set namespace_name = Name(metadata.native_namespace) %}
{% set native_namespace = namespace_name.namespace_case() %}
{% set native_dir = impl_dir + namespace_name.Dirs() %}
-#include "{{native_dir}}/ChainUtils_autogen.h"
+#include "{{native_dir}}/ChainUtils.h"
#include <tuple>
#include <unordered_set>
@@ -93,15 +93,18 @@
// Returns true iff the chain's SType matches the extension, false otherwise. If the SType was
// not already matched, sets the unpacked result accordingly. Otherwise, stores the duplicated
// SType in 'duplicate'.
-template <typename Root, typename Unpacked, typename Ext>
-bool UnpackExtension(Unpacked& unpacked, const ChainedStruct* chain, bool& duplicate) {
+template <typename Root, typename UnpackedT, typename Ext>
+bool UnpackExtension(typename UnpackedT::TupleType& unpacked,
+ typename UnpackedT::BitsetType& bitset,
+ typename UnpackedT::ChainType chain, bool* duplicate) {
DAWN_ASSERT(chain != nullptr);
if (chain->sType == STypeFor<Ext>) {
auto& member = std::get<Ext>(unpacked);
- if (member != nullptr) {
- duplicate = true;
+ if (member != nullptr && duplicate) {
+ *duplicate = true;
} else {
member = reinterpret_cast<Ext>(chain);
+ bitset.set(detail::UnpackedIndexOf<UnpackedT, Ext>);
}
return true;
}
@@ -110,13 +113,17 @@
// Tries to match all possible extensions, returning true iff one of the allowed extensions were
// matched, false otherwise. If the SType was not already matched, sets the unpacked result
-// accordingly. Otherwise, stores the diplicated SType in 'duplicate'.
-template <typename Root, typename Unpacked, typename AdditionalExts>
+// accordingly. Otherwise, stores the duplicated SType in 'duplicate'.
+template <typename Root, typename UnpackedT, typename AdditionalExts>
struct AdditionalExtensionUnpacker;
-template <typename Root, typename Unpacked, typename... Exts>
-struct AdditionalExtensionUnpacker<Root, Unpacked, detail::AdditionalExtensionsList<Exts...>> {
- static bool Unpack(Unpacked& unpacked, const ChainedStruct* chain, bool& duplicate) {
- return ((UnpackExtension<Root, Unpacked, Exts>(unpacked, chain, duplicate)) || ...);
+template <typename Root, typename UnpackedT, typename... Exts>
+struct AdditionalExtensionUnpacker<Root, UnpackedT, detail::AdditionalExtensionsList<Exts...>> {
+ static bool Unpack(typename UnpackedT::TupleType& unpacked,
+ typename UnpackedT::BitsetType& bitset,
+ typename UnpackedT::ChainType chain,
+ bool* duplicate) {
+ return ((UnpackExtension<Root, UnpackedT, Exts>(unpacked, bitset, chain, duplicate)) ||
+ ...);
}
};
@@ -124,52 +131,96 @@
// Unpacked chain helpers.
//
{% for type in by_category["structure"] %}
- {% if type.extensible == "in" %}
- {% set unpackedChain = "Unpacked" + as_cppType(type.name) + "Chain" %}
- ResultOrError<{{unpackedChain}}> ValidateAndUnpackChain(const {{as_cppType(type.name)}}* chain) {
- const ChainedStruct* next = chain->nextInChain;
- {{unpackedChain}} result;
-
- for (; next != nullptr; next = next->nextInChain) {
- bool duplicate = false;
- switch (next->sType) {
- {% for extension in type.extensions %}
- case STypeFor<{{as_cppType(extension.name)}}>: {
- auto& member = std::get<const {{as_cppType(extension.name)}}*>(result);
- if (member != nullptr) {
- duplicate = true;
- } else {
- member = static_cast<const {{as_cppType(extension.name)}}*>(next);
- }
- break;
- }
- {% endfor %}
- default: {
- using Unpacker =
- AdditionalExtensionUnpacker<
- {{as_cppType(type.name)}},
- {{unpackedChain}},
- detail::AdditionalExtensions<{{as_cppType(type.name)}}>::List>;
- if (!Unpacker::Unpack(result, next, duplicate)) {
- return DAWN_VALIDATION_ERROR(
- "Unexpected chained struct of type %s found on %s chain.",
- next->sType, "{{as_cppType(type.name)}}"
+ {% if not type.extensible %}
+ {% continue %}
+ {% endif %}
+ {% set T = as_cppType(type.name) %}
+ {% set UnpackedT = "Unpacked<" + T + ">" %}
+ template <>
+ {{UnpackedT}} Unpack<{{T}}>(typename {{UnpackedT}}::PtrType chain) {
+ {{UnpackedT}} result(chain);
+ for (typename {{UnpackedT}}::ChainType next = chain->nextInChain;
+ next != nullptr;
+ next = next->nextInChain) {
+ switch (next->sType) {
+ {% for extension in type.extensions %}
+ {% set Ext = as_cppType(extension.name) %}
+ case STypeFor<{{Ext}}>: {
+ using ExtPtrType =
+ typename detail::PtrTypeFor<{{UnpackedT}}, {{Ext}}>::Type;
+ std::get<ExtPtrType>(result.mUnpacked) =
+ static_cast<ExtPtrType>(next);
+ result.mBitset.set(
+ detail::UnpackedIndexOf<{{UnpackedT}}, ExtPtrType>
+ );
+ break;
+ }
+ {% endfor %}
+ default: {
+ using Unpacker =
+ AdditionalExtensionUnpacker<
+ {{T}},
+ {{UnpackedT}},
+ detail::AdditionalExtensions<{{T}}>::List>;
+ Unpacker::Unpack(result.mUnpacked, result.mBitset, next, nullptr);
+ break;
+ }
+ }
+ }
+ return result;
+ }
+ template <>
+ ResultOrError<{{UnpackedT}}> ValidateAndUnpack<{{T}}>(typename {{UnpackedT}}::PtrType chain) {
+ {{UnpackedT}} result(chain);
+ for (typename {{UnpackedT}}::ChainType next = chain->nextInChain;
+ next != nullptr;
+ next = next->nextInChain) {
+ bool duplicate = false;
+ switch (next->sType) {
+ {% for extension in type.extensions %}
+ {% set Ext = as_cppType(extension.name) %}
+ case STypeFor<{{Ext}}>: {
+ using ExtPtrType =
+ typename detail::PtrTypeFor<{{UnpackedT}}, {{Ext}}>::Type;
+ auto& member = std::get<ExtPtrType>(result.mUnpacked);
+ if (member != nullptr) {
+ duplicate = true;
+ } else {
+ member = static_cast<ExtPtrType>(next);
+ result.mBitset.set(
+ detail::UnpackedIndexOf<{{UnpackedT}}, ExtPtrType>
);
}
break;
}
- }
- if (duplicate) {
- return DAWN_VALIDATION_ERROR(
- "Duplicate chained struct of type %s found on %s chain.",
- next->sType, "{{as_cppType(type.name)}}"
- );
+ {% endfor %}
+ default: {
+ using Unpacker =
+ AdditionalExtensionUnpacker<
+ {{T}},
+ {{UnpackedT}},
+ detail::AdditionalExtensions<{{T}}>::List>;
+ if (!Unpacker::Unpack(result.mUnpacked,
+ result.mBitset,
+ next,
+ &duplicate)) {
+ return DAWN_VALIDATION_ERROR(
+ "Unexpected chained struct of type %s found on %s chain.",
+ next->sType, "{{T}}"
+ );
+ }
+ break;
}
}
- return result;
+ if (duplicate) {
+ return DAWN_VALIDATION_ERROR(
+ "Duplicate chained struct of type %s found on %s chain.",
+ next->sType, "{{T}}"
+ );
+ }
}
-
- {% endif %}
+ return result;
+ }
{% endfor %}
} // namespace {{native_namespace}}
diff --git a/generator/templates/dawn/native/ChainUtils.h b/generator/templates/dawn/native/ChainUtils.h
index bf3b991..8c220b4 100644
--- a/generator/templates/dawn/native/ChainUtils.h
+++ b/generator/templates/dawn/native/ChainUtils.h
@@ -37,8 +37,6 @@
{% set native_dir = impl_dir + namespace_name.Dirs() %}
{% set prefix = metadata.proc_table_prefix.lower() %}
#include <tuple>
-#include <type_traits>
-#include <unordered_set>
#include "absl/strings/str_format.h"
#include "{{native_dir}}/{{prefix}}_platform.h"
@@ -46,20 +44,20 @@
#include "{{native_dir}}/{{namespace}}_structs_autogen.h"
namespace {{native_namespace}} {
-
namespace detail {
// SType for implementation details. Kept inside the detail namespace for extensibility.
template <typename T>
inline {{namespace}}::SType STypeForImpl;
- // Specialize STypeFor to map from native struct types to their SType.
- {% for value in types["s type"].values %}
- {% if value.valid and value.name.get() in types %}
- template <>
- constexpr inline {{namespace}}::SType STypeForImpl<{{as_cppEnum(value.name)}}> = {{namespace}}::SType::{{as_cppEnum(value.name)}};
- {% endif %}
- {% endfor %}
+// Specialize STypeFor to map from native struct types to their SType.
+{% for value in types["s type"].values %}
+ {% if value.valid and value.name.get() in types %}
+ template <>
+ constexpr inline {{namespace}}::SType STypeForImpl<{{as_cppEnum(value.name)}}> =
+ {{namespace}}::SType::{{as_cppEnum(value.name)}};
+ {% endif %}
+{% endfor %}
template <typename Arg, typename... Rest>
std::string STypesToString() {
@@ -70,12 +68,6 @@
}
}
-//
-// Unpacked chain types structs and helpers.
-// Note that unpacked types are tuples to enable further templating extensions based on
-// typing via something like std::get<const Extension*> in templated functions.
-//
-
// Typelist type used to further add extensions to chain roots when they are not in the json.
template <typename... Exts>
struct AdditionalExtensionsList;
@@ -94,25 +86,6 @@
using Type = std::tuple<Ts..., Additionals...>;
};
-// Template function that returns a string of the non-nullptr STypes from an unpacked chain.
-template <typename Unpacked>
-std::string UnpackedChainToString(const Unpacked& unpacked) {
- std::string result = "( ";
- std::apply(
- [&](const auto*... args) {
- (([&](const auto* arg) {
- if (arg != nullptr) {
- // reinterpret_cast because this chained struct might be forward-declared
- // without a definition. The definition may only be available on a
- // particular backend.
- const auto* chainedStruct = reinterpret_cast<const wgpu::ChainedStruct*>(arg);
- result += absl::StrFormat("%s, ", chainedStruct->sType);
- }
- }(args)), ...);}, unpacked);
- result += " )";
- return result;
-}
-
} // namespace detail
template <typename T>
@@ -220,36 +193,56 @@
return ValidateSingleSTypeInner(chain, sType, sTypes...);
}
-// Template type to get root type from the unpacked chain and vice-versa.
-template <typename Unpacked>
-struct RootTypeFor;
-template <typename Root>
-struct UnpackedTypeFor;
-
} // namespace {{native_namespace}}
// Include specializations before declaring types for ordering purposes.
#include "{{native_dir}}/ChainUtilsImpl.inl"
namespace {{native_namespace}} {
+namespace detail {
+
+// Template type to get the unpacked chain type from the root type.
+template <typename Root>
+struct UnpackedTypeFor;
+
+// Template for extensible structures typing.
+enum class Extensibility { In, Out };
+template <typename T>
+inline Extensibility ExtensibilityFor;
{% for type in by_category["structure"] %}
+ {% set T = as_cppType(type.name) %}
{% if type.extensible == "in" %}
- {% set unpackedChain = "Unpacked" + as_cppType(type.name) + "Chain" %}
- using {{unpackedChain}} = detail::UnpackedChain<
- detail::AdditionalExtensions<{{as_cppType(type.name)}}>::List{{ "," if len(type.extensions) != 0 else ""}}
- {% for extension in type.extensions %}
- const {{as_cppType(extension.name)}}*{{ "," if not loop.last else "" }}
- {% endfor %}
- >::Type;
template <>
- struct UnpackedTypeFor<{{as_cppType(type.name)}}> {
- using Type = {{unpackedChain}};
+ struct UnpackedTypeFor<{{T}}> {
+ using Type = UnpackedChain<
+ AdditionalExtensions<{{T}}>::List
+ {% for extension in type.extensions %}
+ , const {{as_cppType(extension.name)}}*
+ {% endfor %}
+ >::Type;
};
- ResultOrError<{{unpackedChain}}> ValidateAndUnpackChain(const {{as_cppType(type.name)}}* chain);
+ template <>
+ constexpr inline Extensibility ExtensibilityFor<{{T}}> = Extensibility::In;
+
+ {% elif type.extensible == "out" %}
+ template <>
+ struct UnpackedTypeFor<{{T}}> {
+ using Type = UnpackedChain<
+ AdditionalExtensions<{{T}}>::List
+ {% for extension in type.extensions %}
+ , {{as_cppType(extension.name)}}*
+ {% endfor %}
+ >::Type;
+ };
+ template <>
+ constexpr inline Extensibility ExtensibilityFor<{{T}}> = Extensibility::Out;
+
{% endif %}
{% endfor %}
+} // namespace detail
+
} // namespace {{native_namespace}}
#endif // {{DIR}}_CHAIN_UTILS_H_
diff --git a/src/dawn/native/Buffer.cpp b/src/dawn/native/Buffer.cpp
index 1a5407c..83f7116 100644
--- a/src/dawn/native/Buffer.cpp
+++ b/src/dawn/native/Buffer.cpp
@@ -215,12 +215,12 @@
};
MaybeError ValidateBufferDescriptor(DeviceBase* device, const BufferDescriptor* descriptor) {
- UnpackedBufferDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<BufferDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
DAWN_TRY(ValidateBufferUsage(descriptor->usage));
- if (const auto* hostMappedDesc = std::get<const BufferHostMappedPointer*>(unpacked)) {
+ if (const auto* hostMappedDesc = unpacked.Get<BufferHostMappedPointer>()) {
// TODO(crbug.com/dawn/2018): Properly expose this limit.
uint32_t requiredAlignment = 4096;
if (device->GetAdapter()->GetPhysicalDevice()->GetBackendType() ==
diff --git a/src/dawn/native/ChainUtils.h b/src/dawn/native/ChainUtils.h
index 25a2a57..65b840f 100644
--- a/src/dawn/native/ChainUtils.h
+++ b/src/dawn/native/ChainUtils.h
@@ -31,6 +31,7 @@
#include <bitset>
#include <string>
#include <tuple>
+#include <type_traits>
#include "absl/strings/str_format.h"
#include "dawn/common/Math.h"
@@ -38,56 +39,143 @@
#include "dawn/native/Error.h"
namespace dawn::native {
+namespace detail {
+// Gets the chain type for an extensible type.
+template <typename T>
+struct ChainTypeFor {
+ using Type = typename std::conditional_t<ExtensibilityFor<T> == Extensibility::In,
+ const wgpu::ChainedStruct*,
+ wgpu::ChainedStructOut*>;
+};
+} // namespace detail
+
+template <typename T, typename ChainType>
+class UnpackedBase;
+template <typename T>
+using Unpacked = UnpackedBase<T, typename detail::ChainTypeFor<T>::Type>;
+
+namespace detail {
+// Converts to the expected pointer types depending on the extensibility of the structure.
+template <typename UnpackedT, typename U>
+struct PtrTypeFor;
+template <typename T, typename U>
+struct PtrTypeFor<Unpacked<T>, U> {
+ using Type =
+ typename std::conditional_t<ExtensibilityFor<T> == Extensibility::In, const U*, U*>;
+};
+} // namespace detail
+
+// Unpacks chained structures in a best effort manner (skipping unknown chains) without applying
+// validation. If the same structure is duplicated in the chain, it is unspecified which one the
+// result of Get will be. These are the effective constructors for the wrapper types. Note that
+// these are implemented in the generated ChainUtils_autogen.cpp file.
+template <typename T,
+ typename = std::enable_if_t<detail::ExtensibilityFor<T> == detail::Extensibility::In>>
+Unpacked<T> Unpack(const T* chain);
+template <typename T,
+ typename = std::enable_if_t<detail::ExtensibilityFor<T> == detail::Extensibility::Out>>
+Unpacked<T> Unpack(T* chain);
+
+// Unpacks chained structures into Unpacked<T> while applying validation.
+template <typename T,
+ typename = std::enable_if_t<detail::ExtensibilityFor<T> == detail::Extensibility::In>>
+ResultOrError<Unpacked<T>> ValidateAndUnpack(const T* chain);
+template <typename T,
+ typename = std::enable_if_t<detail::ExtensibilityFor<T> == detail::Extensibility::Out>>
+ResultOrError<Unpacked<T>> ValidateAndUnpack(T* chain);
+
+//
+// Wrapper class for unpacked pointers. The classes essentially acts like a const T* or T* with
+// the additional capabilities to validate and retrieve chained structures.
+//
+template <typename T, typename ChainT>
+class UnpackedBase {
+ public:
+ using ChainType = ChainT;
+ using PtrType = typename detail::PtrTypeFor<UnpackedBase<T, ChainType>, T>::Type;
+ using TupleType = typename detail::UnpackedTypeFor<T>::Type;
+ using BitsetType = typename std::bitset<std::tuple_size_v<TupleType>>;
+
+ UnpackedBase() : mStruct(nullptr) {}
+
+ operator bool() const { return mStruct != nullptr; }
+ PtrType operator->() const { return mStruct; }
+ PtrType operator*() const { return mStruct; }
+
+ // Returns true iff every allowed chain in this unpacked type is nullptr.
+ bool Empty() const;
+ // Returns a string of the non-nullptr STypes from an unpacked chain.
+ std::string ToString() const;
+
+ template <typename In>
+ auto Get() const;
+
+ // Validation functions. See implementations of these below for usage, details, and examples.
+ template <typename... Branches>
+ ResultOrError<wgpu::SType> ValidateBranches() const;
+ template <typename... Allowed>
+ MaybeError ValidateSubset() const;
+
+ private:
+ friend UnpackedBase<T, ChainType> Unpack<T>(PtrType chain);
+ friend ResultOrError<UnpackedBase<T, ChainType>> ValidateAndUnpack<T>(PtrType chain);
+
+ explicit UnpackedBase(PtrType packed) : mStruct(packed) {}
+
+ PtrType mStruct = nullptr;
+ TupleType mUnpacked;
+ BitsetType mBitset;
+};
// Tuple type of a Branch and an optional list of corresponding Extensions.
template <typename B, typename... Exts>
struct Branch;
-// Typelist type used to specify a list of acceptable Branches.
-template <typename... Branches>
-struct BranchList;
+// ------------------------------------------------------------------------------------------------
+// Implementation details start here so that the headers are terse.
+// ------------------------------------------------------------------------------------------------
namespace detail {
// Helpers to get the index in an unpacked tuple type for a particular type.
-template <typename Unpacked, typename Ext>
-struct UnpackedIndexOf;
+template <typename UnpackedT, typename Ext>
+inline size_t UnpackedTupleIndexOf;
template <typename Ext, typename... Exts>
-struct UnpackedIndexOf<std::tuple<const Ext*, Exts...>, Ext> {
- static constexpr size_t value = 0;
-};
-template <typename Ext, typename... Exts>
-struct UnpackedIndexOf<std::tuple<Ext, Exts...>, Ext> {
- static constexpr size_t value = 0;
-};
+constexpr inline size_t UnpackedTupleIndexOf<std::tuple<Ext, Exts...>, Ext> = 0;
template <typename Ext, typename Other, typename... Exts>
-struct UnpackedIndexOf<std::tuple<Other, Exts...>, Ext> {
- static constexpr size_t value = 1 + UnpackedIndexOf<std::tuple<Exts...>, Ext>::value;
-};
+constexpr inline size_t UnpackedTupleIndexOf<std::tuple<Other, Exts...>, Ext> =
+ 1 + UnpackedTupleIndexOf<std::tuple<Exts...>, Ext>;
-template <typename Unpacked, typename... Exts>
-struct UnpackedBitsetForExts {
- // Currently using an internal 64-bit unsigned int for internal representation. This is
- // necessary because std::bitset::operator| is not constexpr until C++23.
- static constexpr auto value = std::bitset<std::tuple_size_v<Unpacked>>(
- ((uint64_t(1) << UnpackedIndexOf<Unpacked, Exts>::value) | ...));
-};
+template <typename UnpackedT, typename Ext>
+inline size_t UnpackedIndexOf;
+template <typename T, typename Ext>
+constexpr inline size_t UnpackedIndexOf<Unpacked<T>, Ext> =
+ UnpackedTupleIndexOf<typename Unpacked<T>::TupleType,
+ typename PtrTypeFor<Unpacked<T>, std::remove_pointer_t<Ext>>::Type>;
-template <typename Unpacked, typename...>
+// Currently using an internal 64-bit unsigned int for internal representation. This is necessary
+// because std::bitset::operator| is not constexpr until C++23.
+template <typename UnpackedT, typename... Exts>
+constexpr inline auto UnpackedBitsetForExts =
+ typename UnpackedT::BitsetType(((uint64_t(1) << UnpackedIndexOf<UnpackedT, Exts>) | ...));
+
+template <typename UnpackedT, typename...>
struct OneBranchValidator;
-template <typename Unpacked, typename B, typename... Exts>
-struct OneBranchValidator<Unpacked, Branch<B, Exts...>> {
- static bool Validate(const Unpacked& unpacked,
- const std::bitset<std::tuple_size_v<Unpacked>>& actual,
- wgpu::SType& match) {
+template <typename UnpackedT, typename R, typename... Exts>
+struct OneBranchValidator<UnpackedT, Branch<R, Exts...>> {
+ using BitsetType = typename UnpackedT::BitsetType;
+
+ static bool Validate(const UnpackedT& unpacked, const BitsetType& actual, wgpu::SType& match) {
// Only check the full bitset when the main branch matches.
- if (std::get<const B*>(unpacked) != nullptr) {
+ if (unpacked.template Get<R>() != nullptr) {
// Allowed set of extensions includes the branch root as well.
- constexpr auto allowed = UnpackedBitsetForExts<Unpacked, B, Exts...>::value;
+ constexpr auto allowed =
+ UnpackedBitsetForExts<UnpackedT, typename detail::PtrTypeFor<UnpackedT, R>::Type,
+ typename detail::PtrTypeFor<UnpackedT, Exts>::Type...>;
// The configuration is allowed if the actual available chains are a subset.
if (IsSubset(actual, allowed)) {
- match = STypeFor<B>;
+ match = STypeFor<R>;
return true;
}
}
@@ -96,58 +184,104 @@
static std::string ToString() {
if constexpr (sizeof...(Exts) > 0) {
- return absl::StrFormat("[ %s -> (%s) ]", STypesToString<B>(),
+ return absl::StrFormat("[ %s -> (%s) ]", STypesToString<R>(),
STypesToString<Exts...>());
} else {
- return absl::StrFormat("[ %s ]", STypesToString<B>());
+ return absl::StrFormat("[ %s ]", STypesToString<R>());
}
}
};
-template <typename Unpacked, typename List>
-struct BranchesValidator;
-template <typename Unpacked, typename... Branches>
-struct BranchesValidator<Unpacked, BranchList<Branches...>> {
- static bool Validate(const Unpacked& unpacked, wgpu::SType& match) {
- // Build a bitset based on which elements in the tuple are actually set. We are essentially
- // just looping over every element in the unpacked tuple, computing the index of the element
- // within the tuple, and setting the respective bit if the element is not nullptr.
- std::bitset<std::tuple_size_v<Unpacked>> actual;
- std::apply(
- [&](const auto*... args) {
- (actual.set(UnpackedIndexOf<Unpacked, decltype(args)>::value, args != nullptr),
- ...);
- },
- unpacked);
+template <typename UnpackedT, typename... Branches>
+struct BranchesValidator {
+ using BitsetType = typename UnpackedT::BitsetType;
- return ((OneBranchValidator<Unpacked, Branches>::Validate(unpacked, actual, match)) || ...);
+ static bool Validate(const UnpackedT& unpacked, const BitsetType& actual, wgpu::SType& match) {
+ return ((OneBranchValidator<UnpackedT, Branches>::Validate(unpacked, actual, match)) ||
+ ...);
}
static std::string ToString() {
- return ((absl::StrFormat(" - %s\n", OneBranchValidator<Unpacked, Branches>::ToString())) +
+ return ((absl::StrFormat(" - %s\n", OneBranchValidator<UnpackedT, Branches>::ToString())) +
...);
}
};
+template <typename UnpackedT, typename... Allowed>
+struct SubsetValidator {
+ using BitsetType = typename UnpackedT::BitsetType;
+
+ static std::string ToString() {
+ return absl::StrFormat("[ %s ]", STypesToString<Allowed...>());
+ }
+
+ static MaybeError Validate(const UnpackedT& unpacked, const BitsetType& bitset) {
+ // Allowed set of extensions includes the branch root as well.
+ constexpr auto allowed =
+ detail::UnpackedBitsetForExts<UnpackedT,
+ typename detail::PtrTypeFor<UnpackedT, Allowed>::Type...>;
+ if (!IsSubset(bitset, allowed)) {
+ return DAWN_VALIDATION_ERROR(
+ "Expected extension set to be a subset of: %s\nInstead found: %s", ToString(),
+ unpacked.ToString());
+ }
+ return {};
+ }
+};
+
} // namespace detail
-// Helper to validate that an unpacked chain retrieved via ValidateAndUnpackChain matches a valid
-// "branch", where a "branch" is defined as a "root" extension and optional follow-up extensions.
-// Returns the wgpu::SType associated with the "root" extension of a "branch" if matched, otherwise
-// returns an error.
+template <typename T, typename ChainType>
+template <typename In>
+auto UnpackedBase<T, ChainType>::Get() const {
+ return std::get<typename detail::PtrTypeFor<UnpackedBase<T, ChainType>, In>::Type>(mUnpacked);
+}
+
+template <typename T, typename ChainType>
+bool UnpackedBase<T, ChainType>::Empty() const {
+ return mBitset.none();
+}
+
+template <typename T, typename ChainType>
+std::string UnpackedBase<T, ChainType>::ToString() const {
+ std::string result = "( ";
+ std::apply(
+ [&](auto*... args) {
+ (([&](auto* arg) {
+ if (arg != nullptr) {
+ // reinterpret_cast because this chained struct might be forward-declared
+ // without a definition. The definition may only be available on a
+ // particular backend.
+ const auto* chainedStruct = reinterpret_cast<ChainType>(arg);
+ result += absl::StrFormat("%s, ", chainedStruct->sType);
+ }
+ }(args)),
+ ...);
+ },
+ mUnpacked);
+ result += " )";
+ return result;
+}
+
+// Validates that an unpacked chain retrieved via ValidateAndUnpack matches a valid "branch",
+// where a "branch" is defined as a required "root" extension and optional follow-up
+// extensions.
+//
+// Returns the wgpu::SType associated with the "root" extension of a "branch" if matched,
+// otherwise returns an error.
//
// Example usage:
-// UnpackedChain u;
-// DAWN_TRY_ASSIGN(u, ValidateAndUnpackChain(desc));
+// Unpacked<T> u;
+// DAWN_TRY_ASSIGN(u, ValidateAndUnpack(desc));
// wgpu::SType rootType;
// DAWN_TRY_ASSIGN(rootType,
-// ValidateBranches<BranchList<Branch<Root1>, Branch<Root2, R2Ext1>>>(u));
+// u.ValidateBranches<Branch<Root1>, Branch<Root2, R2Ext1>>());
// switch (rootType) {
// case Root1: {
// <do something>
// }
// case Root2: {
-// R2Ext1 ext = std::get<const R2Ext1*>(u);
+// R2Ext1 ext = u.Get<R2Ext1>(u);
// if (ext) {
// <do something with optional extension(s)>
// }
@@ -160,18 +294,36 @@
// - only a Root1 extension
// - or a Root2 extension with an optional R2Ext1 extension
// Any other configuration is deemed invalid.
-template <typename Branches, typename Unpacked>
-ResultOrError<wgpu::SType> ValidateBranches(const Unpacked& unpacked) {
- using Validator = detail::BranchesValidator<Unpacked, Branches>;
+template <typename T, typename ChainType>
+template <typename... Branches>
+ResultOrError<wgpu::SType> UnpackedBase<T, ChainType>::ValidateBranches() const {
+ using Validator = detail::BranchesValidator<UnpackedBase<T, ChainType>, Branches...>;
wgpu::SType match = wgpu::SType::Invalid;
- if (Validator::Validate(unpacked, match)) {
+ if (Validator::Validate(*this, mBitset, match)) {
return match;
}
return DAWN_VALIDATION_ERROR(
- "Expected chain root to match one of the following branch types with optional extensions:\n"
- "%sInstead found: %s",
- Validator::ToString(), detail::UnpackedChainToString(unpacked));
+ "Expected chain root to match one of the following branch types with optional extensions:"
+ "\n%sInstead found: %s",
+ Validator::ToString(), ToString());
+}
+
+// Validates that an unpacked chain retrieved via ValidateAndUnpack contains a subset of the
+// Allowed extensions. If there are any other extensions, returns an error.
+//
+// Example usage:
+// Unpacked<T> u;
+// DAWN_TRY_ASSIGN(u, ValidateAndUnpack(desc));
+// DAWN_TRY(u.ValidateSubset<Ext1>());
+//
+// Even though "valid" extensions on descriptor may include both Ext1 and Ext2, ValidateSubset
+// will further enforce that Ext2 is not on the chain in the example above.
+template <typename T, typename ChainType>
+template <typename... Allowed>
+MaybeError UnpackedBase<T, ChainType>::ValidateSubset() const {
+ return detail::SubsetValidator<UnpackedBase<T, ChainType>, Allowed...>::Validate(*this,
+ mBitset);
}
} // namespace dawn::native
diff --git a/src/dawn/native/ComputePipeline.cpp b/src/dawn/native/ComputePipeline.cpp
index 8de2613..6898ffe 100644
--- a/src/dawn/native/ComputePipeline.cpp
+++ b/src/dawn/native/ComputePipeline.cpp
@@ -36,10 +36,9 @@
MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
const ComputePipelineDescriptor* descriptor) {
- UnpackedComputePipelineDescriptorChain unpackedChain;
- DAWN_TRY_ASSIGN(unpackedChain, ValidateAndUnpackChain(descriptor));
- const auto* fullSubgroupsOption =
- std::get<const DawnComputePipelineFullSubgroups*>(unpackedChain);
+ Unpacked<ComputePipelineDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
+ const auto* fullSubgroupsOption = unpacked.Get<DawnComputePipelineFullSubgroups>();
DAWN_INVALID_IF(
(fullSubgroupsOption && !device->HasFeature(Feature::ChromiumExperimentalSubgroups)),
"DawnComputePipelineFullSubgroups is used without %s enabled.",
diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp
index 84e76ba..a1f0236 100644
--- a/src/dawn/native/Instance.cpp
+++ b/src/dawn/native/Instance.cpp
@@ -206,15 +206,15 @@
// TODO(crbug.com/dawn/832): make the platform an initialization parameter of the instance.
MaybeError InstanceBase::Initialize(const InstanceDescriptor* descriptor) {
- UnpackedInstanceDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<InstanceDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
// Initialize the platform to the default for now.
mDefaultPlatform = std::make_unique<dawn::platform::Platform>();
SetPlatform(mDefaultPlatform.get());
// Process DawnInstanceDescriptor
- if (const auto* dawnDesc = std::get<const DawnInstanceDescriptor*>(unpacked)) {
+ if (const auto* dawnDesc = unpacked.Get<DawnInstanceDescriptor>()) {
for (uint32_t i = 0; i < dawnDesc->additionalRuntimeSearchPathsCount; ++i) {
mRuntimeSearchPaths.push_back(dawnDesc->additionalRuntimeSearchPaths[i]);
}
@@ -233,7 +233,7 @@
mCallbackTaskManager = AcquireRef(new CallbackTaskManager());
DAWN_TRY(mEventManager.Initialize(descriptor));
- GatherWGSLFeatures(std::get<const DawnWGSLBlocklist*>(unpacked));
+ GatherWGSLFeatures(unpacked.Get<DawnWGSLBlocklist>());
return {};
}
diff --git a/src/dawn/native/RenderPipeline.cpp b/src/dawn/native/RenderPipeline.cpp
index 3741a07..f5cc44e 100644
--- a/src/dawn/native/RenderPipeline.cpp
+++ b/src/dawn/native/RenderPipeline.cpp
@@ -317,10 +317,9 @@
descriptor->format, wgpu::CompareFunction::Undefined, descriptor->depthWriteEnabled,
descriptor->stencilFront.depthFailOp, descriptor->stencilBack.depthFailOp);
- UnpackedDepthStencilStateChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
- if (const auto* depthWriteDefined =
- std::get<const DepthStencilStateDepthWriteDefinedDawn*>(unpacked)) {
+ Unpacked<DepthStencilState> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
+ if (const auto* depthWriteDefined = unpacked.Get<DepthStencilStateDepthWriteDefinedDawn>()) {
DAWN_INVALID_IF(
format->HasDepth() && !depthWriteDefined->depthWriteDefined,
"Depth stencil format (%s) has a depth aspect and depthWriteEnabled is undefined.",
diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp
index e278e78..db40002 100644
--- a/src/dawn/native/ShaderModule.cpp
+++ b/src/dawn/native/ShaderModule.cpp
@@ -1004,26 +1004,26 @@
tint::Source::File file;
};
-// A WGSL (or SPIR-V, if enabled) subdescriptor is required, and a Dawn-specific SPIR-V options
-// descriptor is allowed when using SPIR-V.
-#if TINT_BUILD_SPV_READER
-using ShaderModuleDescriptorBranches =
- BranchList<Branch<ShaderModuleWGSLDescriptor>,
- Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor>>;
-#else
-using ShaderModuleDescriptorBranches = BranchList<Branch<ShaderModuleWGSLDescriptor>>;
-#endif
-
MaybeError ValidateAndParseShaderModule(DeviceBase* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages) {
DAWN_ASSERT(parseResult != nullptr);
- UnpackedShaderModuleDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<ShaderModuleDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType moduleType;
- DAWN_TRY_ASSIGN(moduleType, (ValidateBranches<ShaderModuleDescriptorBranches>(unpacked)));
+ // A WGSL (or SPIR-V, if enabled) subdescriptor is required, and a Dawn-specific SPIR-V options
+// descriptor is allowed when using SPIR-V.
+#if TINT_BUILD_SPV_READER
+ DAWN_TRY_ASSIGN(
+ moduleType,
+ (unpacked.ValidateBranches<
+ Branch<ShaderModuleWGSLDescriptor>,
+ Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor>>()));
+#else
+ DAWN_TRY_ASSIGN(moduleType, (unpacked.ValidateBranches<Branch<ShaderModuleWGSLDescriptor>>()));
+#endif
DAWN_ASSERT(moduleType != wgpu::SType::Invalid);
ScopedTintICEHandler scopedICEHandler(device);
@@ -1040,9 +1040,8 @@
case wgpu::SType::ShaderModuleSPIRVDescriptor: {
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowSpirv),
"SPIR-V is disallowed.");
- const auto* spirvDesc = std::get<const ShaderModuleSPIRVDescriptor*>(unpacked);
- const auto* spirvOptions =
- std::get<const DawnShaderModuleSPIRVOptionsDescriptor*>(unpacked);
+ const auto* spirvDesc = unpacked.Get<ShaderModuleSPIRVDescriptor>();
+ const auto* spirvOptions = unpacked.Get<DawnShaderModuleSPIRVOptionsDescriptor>();
// TODO(dawn:2033): Avoid unnecessary copies of the SPIR-V code.
std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
@@ -1060,7 +1059,7 @@
}
#endif // TINT_BUILD_SPV_READER
case wgpu::SType::ShaderModuleWGSLDescriptor: {
- wgslDesc = std::get<const ShaderModuleWGSLDescriptor*>(unpacked);
+ wgslDesc = unpacked.Get<ShaderModuleWGSLDescriptor>();
break;
}
default:
diff --git a/src/dawn/native/d3d11/DeviceD3D11.cpp b/src/dawn/native/d3d11/DeviceD3D11.cpp
index 7e89ad9..7412ccf 100644
--- a/src/dawn/native/d3d11/DeviceD3D11.cpp
+++ b/src/dawn/native/d3d11/DeviceD3D11.cpp
@@ -254,14 +254,13 @@
ResultOrError<Ref<SharedTextureMemoryBase>> Device::ImportSharedTextureMemoryImpl(
const SharedTextureMemoryDescriptor* descriptor) {
- UnpackedSharedTextureMemoryDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<SharedTextureMemoryDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType type;
DAWN_TRY_ASSIGN(
- type, (ValidateBranches<BranchList<Branch<SharedTextureMemoryDXGISharedHandleDescriptor>,
- Branch<SharedTextureMemoryD3D11Texture2DDescriptor>>>(
- unpacked)));
+ type, (unpacked.ValidateBranches<Branch<SharedTextureMemoryDXGISharedHandleDescriptor>,
+ Branch<SharedTextureMemoryD3D11Texture2DDescriptor>>()));
switch (type) {
case wgpu::SType::SharedTextureMemoryDXGISharedHandleDescriptor:
@@ -270,14 +269,14 @@
wgpu::FeatureName::SharedTextureMemoryDXGISharedHandle);
return SharedTextureMemory::Create(
this, descriptor->label,
- std::get<const SharedTextureMemoryDXGISharedHandleDescriptor*>(unpacked));
+ unpacked.Get<SharedTextureMemoryDXGISharedHandleDescriptor>());
case wgpu::SType::SharedTextureMemoryD3D11Texture2DDescriptor:
DAWN_INVALID_IF(!HasFeature(Feature::SharedTextureMemoryD3D11Texture2D),
"%s is not enabled.",
wgpu::FeatureName::SharedTextureMemoryD3D11Texture2D);
return SharedTextureMemory::Create(
this, descriptor->label,
- std::get<const SharedTextureMemoryD3D11Texture2DDescriptor*>(unpacked));
+ unpacked.Get<SharedTextureMemoryD3D11Texture2DDescriptor>());
default:
DAWN_UNREACHABLE();
}
@@ -285,21 +284,19 @@
ResultOrError<Ref<SharedFenceBase>> Device::ImportSharedFenceImpl(
const SharedFenceDescriptor* descriptor) {
- UnpackedSharedFenceDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<SharedFenceDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType type;
- DAWN_TRY_ASSIGN(
- type,
- (ValidateBranches<BranchList<Branch<SharedFenceDXGISharedHandleDescriptor>>>(unpacked)));
+ DAWN_TRY_ASSIGN(type,
+ (unpacked.ValidateBranches<Branch<SharedFenceDXGISharedHandleDescriptor>>()));
switch (type) {
case wgpu::SType::SharedFenceDXGISharedHandleDescriptor:
DAWN_INVALID_IF(!HasFeature(Feature::SharedFenceDXGISharedHandle), "%s is not enabled.",
wgpu::FeatureName::SharedFenceDXGISharedHandle);
- return SharedFence::Create(
- this, descriptor->label,
- std::get<const SharedFenceDXGISharedHandleDescriptor*>(unpacked));
+ return SharedFence::Create(this, descriptor->label,
+ unpacked.Get<SharedFenceDXGISharedHandleDescriptor>());
default:
DAWN_UNREACHABLE();
}
diff --git a/src/dawn/native/d3d12/DeviceD3D12.cpp b/src/dawn/native/d3d12/DeviceD3D12.cpp
index 75c5ac5..4030715 100644
--- a/src/dawn/native/d3d12/DeviceD3D12.cpp
+++ b/src/dawn/native/d3d12/DeviceD3D12.cpp
@@ -478,13 +478,12 @@
ResultOrError<Ref<SharedTextureMemoryBase>> Device::ImportSharedTextureMemoryImpl(
const SharedTextureMemoryDescriptor* descriptor) {
- UnpackedSharedTextureMemoryDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<SharedTextureMemoryDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType type;
DAWN_TRY_ASSIGN(
- type, (ValidateBranches<BranchList<Branch<SharedTextureMemoryDXGISharedHandleDescriptor>>>(
- unpacked)));
+ type, (unpacked.ValidateBranches<Branch<SharedTextureMemoryDXGISharedHandleDescriptor>>()));
switch (type) {
case wgpu::SType::SharedTextureMemoryDXGISharedHandleDescriptor:
@@ -493,7 +492,7 @@
wgpu::FeatureName::SharedTextureMemoryDXGISharedHandle);
return SharedTextureMemory::Create(
this, descriptor->label,
- std::get<const SharedTextureMemoryDXGISharedHandleDescriptor*>(unpacked));
+ unpacked.Get<SharedTextureMemoryDXGISharedHandleDescriptor>());
default:
DAWN_UNREACHABLE();
}
@@ -501,21 +500,19 @@
ResultOrError<Ref<SharedFenceBase>> Device::ImportSharedFenceImpl(
const SharedFenceDescriptor* descriptor) {
- UnpackedSharedFenceDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<SharedFenceDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType type;
- DAWN_TRY_ASSIGN(
- type,
- (ValidateBranches<BranchList<Branch<SharedFenceDXGISharedHandleDescriptor>>>(unpacked)));
+ DAWN_TRY_ASSIGN(type,
+ (unpacked.ValidateBranches<Branch<SharedFenceDXGISharedHandleDescriptor>>()));
switch (type) {
case wgpu::SType::SharedFenceDXGISharedHandleDescriptor:
DAWN_INVALID_IF(!HasFeature(Feature::SharedFenceDXGISharedHandle), "%s is not enabled.",
wgpu::FeatureName::SharedFenceDXGISharedHandle);
- return SharedFence::Create(
- this, descriptor->label,
- std::get<const SharedFenceDXGISharedHandleDescriptor*>(unpacked));
+ return SharedFence::Create(this, descriptor->label,
+ unpacked.Get<SharedFenceDXGISharedHandleDescriptor>());
default:
DAWN_UNREACHABLE();
}
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
index b4c6246..1dc01ea 100644
--- a/src/dawn/native/vulkan/DeviceVk.cpp
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
@@ -226,20 +226,18 @@
ResultOrError<Ref<SharedTextureMemoryBase>> Device::ImportSharedTextureMemoryImpl(
const SharedTextureMemoryDescriptor* descriptor) {
- UnpackedSharedTextureMemoryDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<SharedTextureMemoryDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType type;
- DAWN_TRY_ASSIGN(
- type, ValidateBranches<BranchList<Branch<SharedTextureMemoryDmaBufDescriptor>>>(unpacked));
+ DAWN_TRY_ASSIGN(type, unpacked.ValidateBranches<Branch<SharedTextureMemoryDmaBufDescriptor>>());
switch (type) {
case wgpu::SType::SharedTextureMemoryDmaBufDescriptor:
DAWN_INVALID_IF(!HasFeature(Feature::SharedTextureMemoryDmaBuf), "%s is not enabled.",
wgpu::FeatureName::SharedTextureMemoryDmaBuf);
- return SharedTextureMemory::Create(
- this, descriptor->label,
- std::get<const SharedTextureMemoryDmaBufDescriptor*>(unpacked));
+ return SharedTextureMemory::Create(this, descriptor->label,
+ unpacked.Get<SharedTextureMemoryDmaBufDescriptor>());
default:
DAWN_UNREACHABLE();
}
@@ -247,15 +245,14 @@
ResultOrError<Ref<SharedFenceBase>> Device::ImportSharedFenceImpl(
const SharedFenceDescriptor* descriptor) {
- UnpackedSharedFenceDescriptorChain unpacked;
- DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpackChain(descriptor));
+ Unpacked<SharedFenceDescriptor> unpacked;
+ DAWN_TRY_ASSIGN(unpacked, ValidateAndUnpack(descriptor));
wgpu::SType type;
DAWN_TRY_ASSIGN(
- type,
- (ValidateBranches<BranchList<Branch<SharedFenceVkSemaphoreZirconHandleDescriptor>,
- Branch<SharedFenceVkSemaphoreSyncFDDescriptor>,
- Branch<SharedFenceVkSemaphoreOpaqueFDDescriptor>>>(unpacked)));
+ type, (unpacked.ValidateBranches<Branch<SharedFenceVkSemaphoreZirconHandleDescriptor>,
+ Branch<SharedFenceVkSemaphoreSyncFDDescriptor>,
+ Branch<SharedFenceVkSemaphoreOpaqueFDDescriptor>>()));
switch (type) {
case wgpu::SType::SharedFenceVkSemaphoreZirconHandleDescriptor:
@@ -264,20 +261,18 @@
wgpu::FeatureName::SharedFenceVkSemaphoreZirconHandle);
return SharedFence::Create(
this, descriptor->label,
- std::get<const SharedFenceVkSemaphoreZirconHandleDescriptor*>(unpacked));
+ unpacked.Get<SharedFenceVkSemaphoreZirconHandleDescriptor>());
case wgpu::SType::SharedFenceVkSemaphoreSyncFDDescriptor:
DAWN_INVALID_IF(!HasFeature(Feature::SharedFenceVkSemaphoreSyncFD),
"%s is not enabled.", wgpu::FeatureName::SharedFenceVkSemaphoreSyncFD);
- return SharedFence::Create(
- this, descriptor->label,
- std::get<const SharedFenceVkSemaphoreSyncFDDescriptor*>(unpacked));
+ return SharedFence::Create(this, descriptor->label,
+ unpacked.Get<SharedFenceVkSemaphoreSyncFDDescriptor>());
case wgpu::SType::SharedFenceVkSemaphoreOpaqueFDDescriptor:
DAWN_INVALID_IF(!HasFeature(Feature::SharedFenceVkSemaphoreOpaqueFD),
"%s is not enabled.",
wgpu::FeatureName::SharedFenceVkSemaphoreOpaqueFD);
- return SharedFence::Create(
- this, descriptor->label,
- std::get<const SharedFenceVkSemaphoreOpaqueFDDescriptor*>(unpacked));
+ return SharedFence::Create(this, descriptor->label,
+ unpacked.Get<SharedFenceVkSemaphoreOpaqueFDDescriptor>());
default:
DAWN_UNREACHABLE();
}
diff --git a/src/dawn/tests/unittests/ChainUtilsTests.cpp b/src/dawn/tests/unittests/ChainUtilsTests.cpp
index 3f2fe89..47d3ca9 100644
--- a/src/dawn/tests/unittests/ChainUtilsTests.cpp
+++ b/src/dawn/tests/unittests/ChainUtilsTests.cpp
@@ -331,23 +331,29 @@
// TextureViewDescriptor (as of when this test was written) does not have any valid chains
// in the JSON nor via additional extensions.
TextureViewDescriptor desc;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- static_assert(std::tuple_size_v<decltype(unpacked)> == 0);
- std::apply(
- [](const auto*... args) {
- (([&](const auto* arg) { EXPECT_EQ(args, nullptr); }(args)), ...);
- },
- unpacked);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ static_assert(std::tuple_size_v<decltype(unpacked)::TupleType> == 0);
+ EXPECT_TRUE(unpacked.Empty());
}
{
// InstanceDescriptor has at least 1 valid chain extension.
InstanceDescriptor desc;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- std::apply(
- [](const auto*... args) {
- (([&](const auto* arg) { EXPECT_EQ(args, nullptr); }(args)), ...);
- },
- unpacked);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_TRUE(unpacked.Empty());
+ }
+ {
+ // SharedTextureMemoryProperties (as of when this test was written) does not have any valid
+ // chains in the JSON nor via additional extensions.
+ SharedTextureMemoryProperties properties;
+ auto unpacked = ValidateAndUnpack(&properties).AcquireSuccess();
+ static_assert(std::tuple_size_v<decltype(unpacked)::TupleType> == 0);
+ EXPECT_TRUE(unpacked.Empty());
+ }
+ {
+ // SharedFenceExportInfo has at least 1 valid chain extension.
+ SharedFenceExportInfo properties;
+ auto unpacked = ValidateAndUnpack(&properties).AcquireSuccess();
+ EXPECT_TRUE(unpacked.Empty());
}
}
@@ -359,7 +365,7 @@
TextureViewDescriptor desc;
ChainedStruct chain;
desc.nextInChain = &chain;
- EXPECT_THAT(ValidateAndUnpackChain(&desc).AcquireError()->GetFormattedMessage(),
+ EXPECT_THAT(ValidateAndUnpack(&desc).AcquireError()->GetFormattedMessage(),
HasSubstr("Unexpected"));
}
{
@@ -367,7 +373,7 @@
InstanceDescriptor desc;
ChainedStruct chain;
desc.nextInChain = &chain;
- EXPECT_THAT(ValidateAndUnpackChain(&desc).AcquireError()->GetFormattedMessage(),
+ EXPECT_THAT(ValidateAndUnpack(&desc).AcquireError()->GetFormattedMessage(),
HasSubstr("Unexpected"));
}
}
@@ -378,8 +384,26 @@
InstanceDescriptor desc;
DawnTogglesDescriptor chain;
desc.nextInChain = &chain;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_EQ(std::get<const DawnTogglesDescriptor*>(unpacked), &chain);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ auto ext = unpacked.Get<DawnTogglesDescriptor>();
+ EXPECT_EQ(ext, &chain);
+
+ // For ChainedStructs, the resulting pointer from Get should be a const type.
+ static_assert(std::is_const_v<std::remove_reference_t<decltype(*ext)>>);
+}
+
+// Nominal unpacking valid descriptors should return the expected descriptors in the unpacked type.
+TEST(ChainUtilsTests, ValidateAndUnpackOut) {
+ // DawnAdapterPropertiesPowerPreference is a valid extension for AdapterProperties.
+ AdapterProperties properties;
+ DawnAdapterPropertiesPowerPreference chain;
+ properties.nextInChain = &chain;
+ auto unpacked = ValidateAndUnpack(&properties).AcquireSuccess();
+ auto ext = unpacked.Get<DawnAdapterPropertiesPowerPreference>();
+ EXPECT_EQ(ext, &chain);
+
+ // For ChainedStructOuts, the resulting pointer from Get should not be a const type.
+ static_assert(!std::is_const_v<std::remove_reference_t<decltype(*ext)>>);
}
// Duplicate valid extensions cause an error.
@@ -390,7 +414,19 @@
DawnTogglesDescriptor chain2;
desc.nextInChain = &chain1;
chain1.nextInChain = &chain2;
- EXPECT_THAT(ValidateAndUnpackChain(&desc).AcquireError()->GetFormattedMessage(),
+ EXPECT_THAT(ValidateAndUnpack(&desc).AcquireError()->GetFormattedMessage(),
+ HasSubstr("Duplicate"));
+}
+
+// Duplicate valid extensions cause an error.
+TEST(ChainUtilsTests, ValidateAndUnpackOutDuplicate) {
+ // DawnAdapterPropertiesPowerPreference is a valid extension for AdapterProperties.
+ AdapterProperties properties;
+ DawnAdapterPropertiesPowerPreference chain1;
+ DawnAdapterPropertiesPowerPreference chain2;
+ properties.nextInChain = &chain1;
+ chain1.nextInChain = &chain2;
+ EXPECT_THAT(ValidateAndUnpack(&properties).AcquireError()->GetFormattedMessage(),
HasSubstr("Duplicate"));
}
@@ -401,8 +437,8 @@
InstanceDescriptor desc;
DawnInstanceDescriptor chain;
desc.nextInChain = &chain;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_EQ(std::get<const DawnInstanceDescriptor*>(unpacked), &chain);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_EQ(unpacked.Get<DawnInstanceDescriptor>(), &chain);
}
// Duplicate additional extensions added via template specialization should cause an error.
@@ -413,15 +449,13 @@
DawnInstanceDescriptor chain2;
desc.nextInChain = &chain1;
chain1.nextInChain = &chain2;
- EXPECT_THAT(ValidateAndUnpackChain(&desc).AcquireError()->GetFormattedMessage(),
+ EXPECT_THAT(ValidateAndUnpack(&desc).AcquireError()->GetFormattedMessage(),
HasSubstr("Duplicate"));
}
-using NoExtensionBranches =
- BranchList<Branch<ShaderModuleWGSLDescriptor>, Branch<ShaderModuleSPIRVDescriptor>>;
-using ExtensionBranches =
- BranchList<Branch<ShaderModuleWGSLDescriptor>,
- Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor>>;
+using B1 = Branch<ShaderModuleWGSLDescriptor>;
+using B2 = Branch<ShaderModuleSPIRVDescriptor>;
+using B2Ext = Branch<ShaderModuleSPIRVDescriptor, DawnShaderModuleSPIRVOptionsDescriptor>;
// Validates exacly 1 branch and ensures that there are no other extensions.
TEST(ChainUtilsTests, ValidateBranchesOneValidBranch) {
@@ -430,20 +464,20 @@
{
ShaderModuleWGSLDescriptor chain;
desc.nextInChain = &chain;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_EQ((ValidateBranches<NoExtensionBranches>(unpacked).AcquireSuccess()),
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_EQ((unpacked.ValidateBranches<B1, B2>().AcquireSuccess()),
wgpu::SType::ShaderModuleWGSLDescriptor);
}
{
ShaderModuleSPIRVDescriptor chain;
desc.nextInChain = &chain;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_EQ((ValidateBranches<NoExtensionBranches>(unpacked).AcquireSuccess()),
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_EQ((unpacked.ValidateBranches<B1, B2>().AcquireSuccess()),
wgpu::SType::ShaderModuleSPIRVDescriptor);
// Extensions are optional so validation should still pass when the extension is not
// provided.
- EXPECT_EQ((ValidateBranches<ExtensionBranches>(unpacked).AcquireSuccess()),
+ EXPECT_EQ((unpacked.ValidateBranches<B1, B2Ext>().AcquireSuccess()),
wgpu::SType::ShaderModuleSPIRVDescriptor);
}
}
@@ -453,9 +487,9 @@
ShaderModuleDescriptor desc;
DawnShaderModuleSPIRVOptionsDescriptor chain;
desc.nextInChain = &chain;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_NE((ValidateBranches<NoExtensionBranches>(unpacked).AcquireError()), nullptr);
- EXPECT_NE((ValidateBranches<ExtensionBranches>(unpacked).AcquireError()), nullptr);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_NE((unpacked.ValidateBranches<B1, B2>().AcquireError()), nullptr);
+ EXPECT_NE((unpacked.ValidateBranches<B1, B2Ext>().AcquireError()), nullptr);
}
// Additional chains should cause an error when branches don't allow extensions.
@@ -466,17 +500,17 @@
DawnShaderModuleSPIRVOptionsDescriptor chain2;
desc.nextInChain = &chain1;
chain1.nextInChain = &chain2;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_NE((ValidateBranches<NoExtensionBranches>(unpacked).AcquireError()), nullptr);
- EXPECT_NE((ValidateBranches<ExtensionBranches>(unpacked).AcquireError()), nullptr);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_NE((unpacked.ValidateBranches<B1, B2>().AcquireError()), nullptr);
+ EXPECT_NE((unpacked.ValidateBranches<B1, B2Ext>().AcquireError()), nullptr);
}
{
ShaderModuleSPIRVDescriptor chain1;
DawnShaderModuleSPIRVOptionsDescriptor chain2;
desc.nextInChain = &chain1;
chain1.nextInChain = &chain2;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_NE((ValidateBranches<NoExtensionBranches>(unpacked).AcquireError()), nullptr);
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_NE((unpacked.ValidateBranches<B1, B2>().AcquireError()), nullptr);
}
}
@@ -487,10 +521,168 @@
DawnShaderModuleSPIRVOptionsDescriptor chain2;
desc.nextInChain = &chain1;
chain1.nextInChain = &chain2;
- auto unpacked = ValidateAndUnpackChain(&desc).AcquireSuccess();
- EXPECT_EQ((ValidateBranches<ExtensionBranches>(unpacked).AcquireSuccess()),
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_EQ((unpacked.ValidateBranches<B1, B2Ext>().AcquireSuccess()),
wgpu::SType::ShaderModuleSPIRVDescriptor);
}
+// Unrealistic branching for ChainedStructOut testing. Note that this setup does not make sense.
+using BOut1 = Branch<SharedFenceVkSemaphoreOpaqueFDExportInfo>;
+using BOut2 = Branch<SharedFenceVkSemaphoreSyncFDExportInfo>;
+using BOut2Ext =
+ Branch<SharedFenceVkSemaphoreSyncFDExportInfo, SharedFenceVkSemaphoreZirconHandleExportInfo>;
+
+// Validates exacly 1 branch and ensures that there are no other extensions.
+TEST(ChainUtilsTests, ValidateBranchesOneValidBranchOut) {
+ SharedFenceExportInfo info;
+ // Either allowed branches should validate successfully and return the expected enum.
+ {
+ SharedFenceVkSemaphoreOpaqueFDExportInfo chain;
+ info.nextInChain = &chain;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_EQ((unpacked.ValidateBranches<BOut1, BOut2>().AcquireSuccess()),
+ wgpu::SType::SharedFenceVkSemaphoreOpaqueFDExportInfo);
+ }
+ {
+ SharedFenceVkSemaphoreSyncFDExportInfo chain;
+ info.nextInChain = &chain;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_EQ((unpacked.ValidateBranches<BOut1, BOut2>().AcquireSuccess()),
+ wgpu::SType::SharedFenceVkSemaphoreSyncFDExportInfo);
+
+ // Extensions are optional so validation should still pass when the extension is not
+ // provided.
+ EXPECT_EQ((unpacked.ValidateBranches<BOut1, BOut2Ext>().AcquireSuccess()),
+ wgpu::SType::SharedFenceVkSemaphoreSyncFDExportInfo);
+ }
+}
+
+// An allowed chain that is not one of the branches causes an error.
+TEST(ChainUtilsTests, ValidateBranchesInvalidBranchOut) {
+ SharedFenceExportInfo info;
+ SharedFenceDXGISharedHandleExportInfo chain;
+ info.nextInChain = &chain;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_NE((unpacked.ValidateBranches<BOut1, BOut2>().AcquireError()), nullptr);
+ EXPECT_NE((unpacked.ValidateBranches<BOut1, BOut2Ext>().AcquireError()), nullptr);
+}
+
+// Additional chains should cause an error when branches don't allow extensions.
+TEST(ChainUtilsTests, ValidateBranchesInvalidExtensionOut) {
+ SharedFenceExportInfo info;
+ {
+ SharedFenceVkSemaphoreOpaqueFDExportInfo chain1;
+ SharedFenceVkSemaphoreZirconHandleExportInfo chain2;
+ info.nextInChain = &chain1;
+ chain1.nextInChain = &chain2;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_NE((unpacked.ValidateBranches<BOut1, BOut2>().AcquireError()), nullptr);
+ EXPECT_NE((unpacked.ValidateBranches<BOut1, BOut2Ext>().AcquireError()), nullptr);
+ }
+ {
+ SharedFenceVkSemaphoreSyncFDExportInfo chain1;
+ SharedFenceVkSemaphoreZirconHandleExportInfo chain2;
+ info.nextInChain = &chain1;
+ chain1.nextInChain = &chain2;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_NE((unpacked.ValidateBranches<BOut1, BOut2>().AcquireError()), nullptr);
+ }
+}
+
+// Branches that allow extensions pass successfully.
+TEST(ChainUtilsTests, ValidateBranchesAllowedExtensionsOut) {
+ SharedFenceExportInfo info;
+ SharedFenceVkSemaphoreSyncFDExportInfo chain1;
+ SharedFenceVkSemaphoreZirconHandleExportInfo chain2;
+ info.nextInChain = &chain1;
+ chain1.nextInChain = &chain2;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_EQ((unpacked.ValidateBranches<BOut1, BOut2Ext>().AcquireSuccess()),
+ wgpu::SType::SharedFenceVkSemaphoreSyncFDExportInfo);
+}
+
+// Valid subsets should pass successfully, while invalid ones should error.
+TEST(ChainUtilsTests, ValidateSubset) {
+ DeviceDescriptor desc;
+ DawnTogglesDescriptor chain1;
+ DawnCacheDeviceDescriptor chain2;
+
+ // With none set, subset for anything should work.
+ {
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_TRUE(unpacked.ValidateSubset<DawnTogglesDescriptor>().IsSuccess());
+ EXPECT_TRUE(unpacked.ValidateSubset<DawnCacheDeviceDescriptor>().IsSuccess());
+ EXPECT_TRUE((unpacked.ValidateSubset<DawnTogglesDescriptor, DawnCacheDeviceDescriptor>()
+ .IsSuccess()));
+ }
+ // With one set, subset with that allow that one should pass. Otherwise it should fail.
+ {
+ desc.nextInChain = &chain1;
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_TRUE(unpacked.ValidateSubset<DawnTogglesDescriptor>().IsSuccess());
+ EXPECT_NE(unpacked.ValidateSubset<DawnCacheDeviceDescriptor>().AcquireError(), nullptr);
+ EXPECT_TRUE((unpacked.ValidateSubset<DawnTogglesDescriptor, DawnCacheDeviceDescriptor>()
+ .IsSuccess()));
+ }
+ // With both set, single subsets should all fail.
+ {
+ chain1.nextInChain = &chain2;
+ auto unpacked = ValidateAndUnpack(&desc).AcquireSuccess();
+ EXPECT_NE(unpacked.ValidateSubset<DawnTogglesDescriptor>().AcquireError(), nullptr);
+ EXPECT_NE(unpacked.ValidateSubset<DawnCacheDeviceDescriptor>().AcquireError(), nullptr);
+ EXPECT_TRUE((unpacked.ValidateSubset<DawnTogglesDescriptor, DawnCacheDeviceDescriptor>()
+ .IsSuccess()));
+ }
+}
+
+// Valid subsets should pass successfully, while invalid ones should error.
+TEST(ChainUtilsTests, ValidateSubsetOut) {
+ SharedFenceExportInfo info;
+ SharedFenceVkSemaphoreOpaqueFDExportInfo chain1;
+ SharedFenceVkSemaphoreZirconHandleExportInfo chain2;
+
+ // With none set, subset for anything should work.
+ {
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_TRUE(
+ unpacked.ValidateSubset<SharedFenceVkSemaphoreOpaqueFDExportInfo>().IsSuccess());
+ EXPECT_TRUE(
+ unpacked.ValidateSubset<SharedFenceVkSemaphoreZirconHandleExportInfo>().IsSuccess());
+ EXPECT_TRUE((unpacked
+ .ValidateSubset<SharedFenceVkSemaphoreOpaqueFDExportInfo,
+ SharedFenceVkSemaphoreZirconHandleExportInfo>()
+ .IsSuccess()));
+ }
+ // With one set, subset with that allow that one should pass. Otherwise it should fail.
+ {
+ info.nextInChain = &chain1;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_TRUE(
+ unpacked.ValidateSubset<SharedFenceVkSemaphoreOpaqueFDExportInfo>().IsSuccess());
+ EXPECT_NE(
+ unpacked.ValidateSubset<SharedFenceVkSemaphoreZirconHandleExportInfo>().AcquireError(),
+ nullptr);
+ EXPECT_TRUE((unpacked
+ .ValidateSubset<SharedFenceVkSemaphoreOpaqueFDExportInfo,
+ SharedFenceVkSemaphoreZirconHandleExportInfo>()
+ .IsSuccess()));
+ }
+ // With both set, single subsets should all fail.
+ {
+ chain1.nextInChain = &chain2;
+ auto unpacked = ValidateAndUnpack(&info).AcquireSuccess();
+ EXPECT_NE(
+ unpacked.ValidateSubset<SharedFenceVkSemaphoreOpaqueFDExportInfo>().AcquireError(),
+ nullptr);
+ EXPECT_NE(
+ unpacked.ValidateSubset<SharedFenceVkSemaphoreZirconHandleExportInfo>().AcquireError(),
+ nullptr);
+ EXPECT_TRUE((unpacked
+ .ValidateSubset<SharedFenceVkSemaphoreOpaqueFDExportInfo,
+ SharedFenceVkSemaphoreZirconHandleExportInfo>()
+ .IsSuccess()));
+ }
+}
+
} // anonymous namespace
} // namespace dawn::native