blob: 0146abf60edc83cabdb260e1e5dd97e61f75b8e7 [file] [log] [blame]
// Copyright 2023 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_NATIVE_CHAINUTILS_H_
#define SRC_DAWN_NATIVE_CHAINUTILS_H_
#include <bitset>
#include <string>
#include <tuple>
#include "absl/strings/str_format.h"
#include "dawn/common/Math.h"
#include "dawn/native/ChainUtils_autogen.h"
#include "dawn/native/Error.h"
namespace dawn::native {
// Type-level description of one acceptable chain configuration: a root extension type (`Root`)
// followed by zero or more extension types that may accompany that root. Only declared here; it
// is used as a template argument to pick partial specializations.
template <typename Root, typename... Extensions>
struct Branch;
// Type-level list naming every Branch that is considered acceptable. Only declared here; it is
// used as a template argument to pick partial specializations.
template <typename... AcceptedBranches>
struct BranchList;
namespace detail {
// Compile-time lookup of the position of a type inside an unpacked tuple type.
//
// UnpackedIndexOf<std::tuple<Ts...>, Wanted>::value is the zero-based index of the first tuple
// element that is either `const Wanted*` or `Wanted` itself. The primary template is only
// declared, so a lookup that no specialization below can satisfy (for example a type that does
// not appear in the tuple) fails to compile.
template <typename Tuple, typename Wanted>
struct UnpackedIndexOf;

// Head of the tuple is a pointer-to-const of the requested type: found at this position.
template <typename Wanted, typename... Tail>
struct UnpackedIndexOf<std::tuple<const Wanted*, Tail...>, Wanted> {
    static constexpr size_t value = 0;
};

// Head of the tuple is exactly the requested type: found at this position.
template <typename Wanted, typename... Tail>
struct UnpackedIndexOf<std::tuple<Wanted, Tail...>, Wanted> {
    static constexpr size_t value = 0;
};

// Head of the tuple is something else: drop it and count one more position.
template <typename Wanted, typename Head, typename... Tail>
struct UnpackedIndexOf<std::tuple<Head, Tail...>, Wanted> {
    static constexpr size_t value = UnpackedIndexOf<std::tuple<Tail...>, Wanted>::value + 1;
};
// Builds, at compile time, a bitset with one bit per element of `Unpacked` in which exactly the
// bits at the tuple positions of `Exts...` are set.
template <typename Unpacked, typename... Exts>
struct UnpackedBitsetForExts {
  private:
    using Bits = std::bitset<std::tuple_size_v<Unpacked>>;

    // The mask is accumulated in a 64-bit unsigned integer and only then converted to a bitset,
    // because std::bitset::operator| is not constexpr until C++23. Each extension contributes
    // the single bit at its index within the unpacked tuple.
    static constexpr uint64_t kMask =
        (... | (uint64_t(1) << UnpackedIndexOf<Unpacked, Exts>::value));

  public:
    static constexpr Bits value = Bits(kMask);
};
// Validator for a single Branch. Only the Branch<...> specialization below is defined.
template <typename Unpacked, typename...>
struct OneBranchValidator;

template <typename Unpacked, typename Root, typename... Exts>
struct OneBranchValidator<Unpacked, Branch<Root, Exts...>> {
    // Checks whether `unpacked` is an instance of this branch.
    //
    // `actual` holds one bit per element of `unpacked`, as computed by the caller. On a match,
    // STypeFor<Root> is written to `match` and true is returned; otherwise `match` is left
    // untouched and false is returned.
    static bool Validate(const Unpacked& unpacked,
                         const std::bitset<std::tuple_size_v<Unpacked>>& actual,
                         wgpu::SType& match) {
        // The extension set is only worth inspecting when the root of this branch is present.
        const bool rootPresent = std::get<const Root*>(unpacked) != nullptr;
        if (!rootPresent) {
            return false;
        }

        // Everything this branch tolerates: the root itself plus its optional extensions.
        constexpr auto kPermitted = UnpackedBitsetForExts<Unpacked, Root, Exts...>::value;

        // The chain is acceptable only if what is actually present is a subset of that.
        if (!IsSubset(actual, kPermitted)) {
            return false;
        }

        match = STypeFor<Root>;
        return true;
    }

    // Human-readable description of this branch: "[ root ]" when the branch has no extensions,
    // "[ root -> (extensions) ]" otherwise.
    static std::string ToString() {
        if constexpr (sizeof...(Exts) == 0) {
            return absl::StrFormat("[ %s ]", STypesToString<Root>());
        } else {
            return absl::StrFormat("[ %s -> (%s) ]", STypesToString<Root>(),
                                   STypesToString<Exts...>());
        }
    }
};
// Validator for a whole BranchList. Only the BranchList<...> specialization below is defined.
template <typename Unpacked, typename List>
struct BranchesValidator;

template <typename Unpacked, typename... Branches>
struct BranchesValidator<Unpacked, BranchList<Branches...>> {
    // Returns true if `unpacked` matches at least one of `Branches`. Branches are tried in
    // declaration order and evaluation stops at the first one that matches; that branch is the
    // one that writes `match`.
    static bool Validate(const Unpacked& unpacked, wgpu::SType& match) {
        using PresenceBits = std::bitset<std::tuple_size_v<Unpacked>>;

        // Record which tuple elements are populated: visit every pointer in the tuple, look up
        // its index within the tuple type, and set that bit iff the pointer is non-null.
        PresenceBits present;
        const auto recordPresence = [&present](const auto*... exts) {
            (present.set(UnpackedIndexOf<Unpacked, decltype(exts)>::value, exts != nullptr), ...);
        };
        std::apply(recordPresence, unpacked);

        return (OneBranchValidator<Unpacked, Branches>::Validate(unpacked, present, match) || ...);
    }

    // Concatenates one " - <branch>\n" line per branch, in declaration order.
    static std::string ToString() {
        return (absl::StrFormat(" - %s\n", OneBranchValidator<Unpacked, Branches>::ToString()) +
                ...);
    }
};
} // namespace detail
// Helper to validate that an unpacked chain retrieved via ValidateAndUnpackChain matches a valid
// "branch", where a "branch" is defined as a "root" extension and optional follow-up extensions.
// Returns the wgpu::SType associated with the "root" extension of a "branch" if matched, otherwise
// returns an error.
//
// Example usage:
//     UnpackedChain u;
//     DAWN_TRY_ASSIGN(u, ValidateAndUnpackChain(desc));
//     wgpu::SType rootType;
//     DAWN_TRY_ASSIGN(rootType,
//         ValidateBranches<BranchList<Branch<Root1>, Branch<Root2, R2Ext1>>>(u));
//     switch (rootType) {
//         case STypeFor<Root1>: {
//             <do something>
//             break;
//         }
//         case STypeFor<Root2>: {
//             const R2Ext1* ext = std::get<const R2Ext1*>(u);
//             if (ext) {
//                 <do something with optional extension(s)>
//             }
//             break;
//         }
//         default:
//             DAWN_UNREACHABLE();
//     }
//
// The example above checks that the unpacked chain is either:
// - only a Root1 extension
// - or a Root2 extension with an optional R2Ext1 extension
// Any other configuration is deemed invalid.
template <typename Branches, typename Unpacked>
ResultOrError<wgpu::SType> ValidateBranches(const Unpacked& unpacked) {
    using BranchChecker = detail::BranchesValidator<Unpacked, Branches>;

    // Only overwritten when one of the branches accepts the chain.
    wgpu::SType rootType = wgpu::SType::Invalid;
    const bool accepted = BranchChecker::Validate(unpacked, rootType);
    if (accepted) {
        return rootType;
    }

    // No branch matched: list every acceptable branch next to what the chain actually held.
    return DAWN_VALIDATION_ERROR(
        "Expected chain root to match one of the following branch types with optional extensions:\n"
        "%sInstead found: %s",
        BranchChecker::ToString(), detail::UnpackedChainToString(unpacked));
}
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CHAINUTILS_H_