tint: Implement Switch() without recursion
This removes the number of function calls made in non-optimized builds.
Reduces the optimized, all-features-enabled 'tint' executable size by
about 1%.
This change removes the bloom filter optimizations which provided
substantial performance gains with the old recursive implementation,
however this still appears to be ~1% faster than the optimized version.
Change-Id: Ic2bb82e9182459e37907f9e0d0b4771bde218f9f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/123440
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Kokoro: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/castable.h b/src/tint/castable.h
index acb9a18..71e3ab8 100644
--- a/src/tint/castable.h
+++ b/src/tint/castable.h
@@ -21,9 +21,7 @@
#include <utility>
#include "src/tint/traits.h"
-#include "src/tint/utils/bitcast.h"
#include "src/tint/utils/crc32.h"
-#include "src/tint/utils/defer.h"
#if defined(__clang__)
/// Temporarily disable certain warnings when using Castable API
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index e4eab291a..693e1b3 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -42,6 +42,7 @@
#include "src/tint/type/multisampled_texture.h"
#include "src/tint/type/sampled_texture.h"
#include "src/tint/type/texture_dimension.h"
+#include "src/tint/utils/defer.h"
#include "src/tint/utils/reverse.h"
#include "src/tint/utils/string.h"
#include "src/tint/utils/string_stream.h"
diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc
index cf53c4c..200221a 100644
--- a/src/tint/resolver/uniformity.cc
+++ b/src/tint/resolver/uniformity.cc
@@ -39,6 +39,7 @@
#include "src/tint/sem/while_statement.h"
#include "src/tint/switch.h"
#include "src/tint/utils/block_allocator.h"
+#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/string_stream.h"
#include "src/tint/utils/unique_vector.h"
diff --git a/src/tint/switch.h b/src/tint/switch.h
index 9ae8d3b..5116195 100644
--- a/src/tint/switch.h
+++ b/src/tint/switch.h
@@ -19,6 +19,8 @@
#include <utility>
#include "src/tint/castable.h"
+#include "src/tint/utils/bitcast.h"
+#include "src/tint/utils/defer.h"
namespace tint {
@@ -62,126 +64,6 @@
}
}
-/// The implementation of Switch() for non-Default cases.
-/// Switch splits the cases into two a low and high block of cases, and quickly rules out blocks
-/// that cannot match by comparing the HashCode of the object and the cases in the block. If a block
-/// of cases may match the given object's type, then that block is split into two, and the process
-/// recurses. When NonDefaultCases() is called with a single case, then As<> will be used to
-/// dynamically cast to the case type and if the cast succeeds, then the case handler is called.
-/// @returns true if a case handler was found, otherwise false.
-template <typename T, typename RETURN_TYPE, typename... CASES>
-inline bool NonDefaultCases([[maybe_unused]] T* object,
- const TypeInfo* type,
- [[maybe_unused]] RETURN_TYPE* result,
- std::tuple<CASES...>&& cases) {
- using Cases = std::tuple<CASES...>;
-
- static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
- static constexpr size_t kNumCases = sizeof...(CASES);
-
- if constexpr (kNumCases == 0) {
- // No cases. Nothing to do.
- return false;
- } else if constexpr (kNumCases == 1) { // NOLINT: cpplint doesn't understand
- // `else if constexpr`
- // Single case.
- using CaseFunc = std::tuple_element_t<0, Cases>;
- static_assert(!IsDefaultCase<CaseFunc>, "NonDefaultCases called with a Default case");
- // Attempt to dynamically cast the object to the handler type. If that succeeds, call the
- // case handler with the cast object.
- using CaseType = SwitchCaseType<CaseFunc>;
- if (type->Is<CaseType>()) {
- auto* ptr = static_cast<CaseType*>(object);
- if constexpr (kHasReturnType) {
- new (result) RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<0>(cases)(ptr)));
- } else {
- std::get<0>(cases)(ptr);
- }
- return true;
- }
- return false;
- } else {
- // Multiple cases.
- // Check the hashcode bits to see if there's any possibility of a case matching in these
- // cases. If there isn't, we can skip all these cases.
- if (MaybeAnyOf(TypeInfo::CombinedHashCodeOf<SwitchCaseType<CASES>...>(),
- type->full_hashcode)) {
- // Split the cases into two, and recurse.
- constexpr size_t kMid = kNumCases / 2;
- return NonDefaultCases(object, type, result, traits::Slice<0, kMid>(cases)) ||
- NonDefaultCases(object, type, result,
- traits::Slice<kMid, kNumCases - kMid>(cases));
- } else {
- return false;
- }
- }
-}
-
-/// The implementation of Switch() for all cases.
-/// @see NonDefaultCases
-template <typename T, typename RETURN_TYPE, typename... CASES>
-inline void SwitchCases(T* object, RETURN_TYPE* result, std::tuple<CASES...>&& cases) {
- using Cases = std::tuple<CASES...>;
-
- static constexpr int kDefaultIndex = detail::IndexOfDefaultCase<Cases>();
- static constexpr bool kHasDefaultCase = kDefaultIndex >= 0;
- static constexpr bool kHasReturnType = !std::is_same_v<RETURN_TYPE, void>;
-
- // Static assertions
- static constexpr bool kDefaultIsOK =
- kDefaultIndex == -1 || kDefaultIndex == static_cast<int>(std::tuple_size_v<Cases> - 1);
- static constexpr bool kReturnIsOK =
- kHasDefaultCase || !kHasReturnType || std::is_constructible_v<RETURN_TYPE>;
- static_assert(kDefaultIsOK, "Default case must be last in Switch()");
- static_assert(kReturnIsOK,
- "Switch() requires either a Default case or a return type that is either void or "
- "default-constructable");
-
- // If the static asserts have fired, don't bother spewing more errors below
- static constexpr bool kAllOK = kDefaultIsOK && kReturnIsOK;
- if constexpr (kAllOK) {
- if (object) {
- auto* type = &object->TypeInfo();
- if constexpr (kHasDefaultCase) {
- // Evaluate non-default cases.
- if (!detail::NonDefaultCases<T>(object, type, result,
- traits::Slice<0, kDefaultIndex>(cases))) {
- // Nothing matched. Evaluate default case.
- if constexpr (kHasReturnType) {
- new (result) RETURN_TYPE(
- static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
- } else {
- std::get<kDefaultIndex>(cases)({});
- }
- }
- } else {
- if (!detail::NonDefaultCases<T>(object, type, result, std::move(cases))) {
- // Nothing matched. No default case.
- if constexpr (kHasReturnType) {
- new (result) RETURN_TYPE();
- }
- }
- }
- } else {
- // Object is nullptr, so no cases can match
- if constexpr (kHasDefaultCase) {
- // Evaluate default case.
- if constexpr (kHasReturnType) {
- new (result)
- RETURN_TYPE(static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({})));
- } else {
- std::get<kDefaultIndex>(cases)({});
- }
- } else {
- // No default case, no case can match.
- if constexpr (kHasReturnType) {
- new (result) RETURN_TYPE();
- }
- }
- }
- }
-}
-
/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore.
template <typename T>
using NullptrToIgnore = std::conditional_t<std::is_same_v<T, std::nullptr_t>, Ignore, T>;
@@ -282,21 +164,95 @@
template <typename RETURN_TYPE = detail::Infer, typename T = CastableBase, typename... CASES>
inline auto Switch(T* object, CASES&&... cases) {
using ReturnType = detail::SwitchReturnType<RETURN_TYPE, traits::ReturnType<CASES>...>;
+ static constexpr int kDefaultIndex = detail::IndexOfDefaultCase<std::tuple<CASES...>>();
+ static constexpr bool kHasDefaultCase = kDefaultIndex >= 0;
static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
+ // Static assertions
+ static constexpr bool kDefaultIsOK =
+ kDefaultIndex == -1 || kDefaultIndex == static_cast<int>(sizeof...(CASES) - 1);
+ static constexpr bool kReturnIsOK =
+ kHasDefaultCase || !kHasReturnType || std::is_constructible_v<ReturnType>;
+ static_assert(kDefaultIsOK, "Default case must be last in Switch()");
+ static_assert(kReturnIsOK,
+ "Switch() requires either a Default case or a return type that is either void or "
+ "default-constructable");
+
+ if (!object) { // Object is nullptr, so no cases can match
+ if constexpr (kHasDefaultCase) {
+ // Evaluate default case.
+ auto&& default_case =
+ std::get<kDefaultIndex>(std::forward_as_tuple(std::forward<CASES>(cases)...));
+ return static_cast<ReturnType>(default_case(Default{}));
+ } else {
+ // No default case, no case can match.
+ if constexpr (kHasReturnType) {
+ return ReturnType{};
+ } else {
+ return;
+ }
+ }
+ }
+
+ // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC.
+ using ReturnTypeOrU8 = std::conditional_t<kHasReturnType, ReturnType, uint8_t>;
+ struct alignas(alignof(ReturnTypeOrU8)) ReturnStorage {
+ uint8_t data[sizeof(ReturnTypeOrU8)];
+ };
+ ReturnStorage storage;
+ auto* result = utils::Bitcast<ReturnTypeOrU8*>(&storage);
+
+ const TypeInfo& type_info = object->TypeInfo();
+
+ // Examines the parameter type of the case function.
+ // If the parameter is a pointer type that `object` is of, or derives from, then that case
+ // function is called with `object` cast to that type, and `try_case` returns true.
+ // If the parameter is of type `Default`, then that case function is called and `try_case`
+ // returns true.
+ // Otherwise `try_case` returns false.
+ // If the case function is called and it returns a value, then this is copy constructed to the
+ // `result` pointer.
+ auto try_case = [&](auto&& case_fn) {
+ using CaseFunc = std::decay_t<decltype(case_fn)>;
+ using CaseType = detail::SwitchCaseType<CaseFunc>;
+ if constexpr (std::is_same_v<CaseType, Default>) {
+ if constexpr (kHasReturnType) {
+ new (result) ReturnType(static_cast<ReturnType>(case_fn(Default{})));
+ } else {
+ case_fn(Default{});
+ }
+ return true;
+ } else {
+ if (type_info.Is<CaseType>()) {
+ auto* v = static_cast<CaseType*>(object);
+ if constexpr (kHasReturnType) {
+ new (result) ReturnType(static_cast<ReturnType>(case_fn(v)));
+ } else {
+ case_fn(v);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Use a logical-or fold expression to try each of the cases in turn, until one matches the
+ // object type or a Default is reached. `handled` is true if a case function was called.
+ bool handled = ((try_case(std::forward<CASES>(cases)) || ...));
+
if constexpr (kHasReturnType) {
- // Replacement for std::aligned_storage as this is broken on earlier versions of MSVC.
- struct alignas(alignof(ReturnType)) ReturnStorage {
- uint8_t data[sizeof(ReturnType)];
- };
- ReturnStorage storage;
- auto* res = utils::Bitcast<ReturnType*>(&storage);
- TINT_DEFER(res->~ReturnType());
- detail::SwitchCases(object, res, std::forward_as_tuple(std::forward<CASES>(cases)...));
- return *res;
- } else {
- detail::SwitchCases<T, void>(object, nullptr,
- std::forward_as_tuple(std::forward<CASES>(cases)...));
+ if constexpr (kHasDefaultCase) {
+ // Default case means there must be a returned value.
+ // No need to check handled, no requirement for a zero-initializer of ReturnType.
+ TINT_DEFER(result->~ReturnType());
+ return *result;
+ } else {
+ if (handled) {
+ TINT_DEFER(result->~ReturnType());
+ return *result;
+ }
+ return ReturnType{};
+ }
}
}
diff --git a/src/tint/utils/slice.h b/src/tint/utils/slice.h
index 719a53e..325c470 100644
--- a/src/tint/utils/slice.h
+++ b/src/tint/utils/slice.h
@@ -20,6 +20,7 @@
#include "src/tint/castable.h"
#include "src/tint/traits.h"
+#include "src/tint/utils/bitcast.h"
namespace tint::utils {
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index 0a167eb..f98cf86 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -35,6 +35,7 @@
#include "src/tint/sem/struct.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/switch.h"
+#include "src/tint/utils/defer.h"
#include "src/tint/utils/math.h"
#include "src/tint/utils/scoped_assignment.h"
#include "src/tint/writer/float_to_string.h"