[eval] Make a TransformTernaryElements
Make a specific non-templated version of the 3 element Transform method.
The generic `TransformElement` is restricted to only trigger if the
special `size` parameter is requested on the function template.
Reduces template instantiations in order to reduce binary size.
Change-Id: I922e75ee9208fe7e5e6ed54983a452ec92554195
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/148281
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 03a8eee..1aeadb1 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -462,23 +462,22 @@
return value_stack.Pop();
}
-namespace detail {
-/// Implementation of TransformElements
+/// TransformElements constructs a new constant of type `composite_ty` by applying the
+/// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all
+/// input constants `cs` are of the same arity (all scalars or all vectors of the same size).
+/// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside
+/// the most deeply nested aggregate type will be passed in.
template <typename F, typename... CONSTANTS>
-Eval::Result TransformElements(Manager& mgr,
- const core::type::Type* composite_ty,
- const F& f,
- size_t index,
- CONSTANTS&&... cs) {
+tint::traits::EnableIf<tint::traits::IsType<size_t, tint::traits::LastParameterType<F>>,
+ Eval::Result>
+TransformElements(Manager& mgr,
+ const core::type::Type* composite_ty,
+ const F& f,
+ size_t index,
+ CONSTANTS&&... cs) {
auto [el_ty, n] = First(cs...)->Type()->Elements();
if (!el_ty) {
- constexpr bool kHasIndexParam =
- tint::traits::IsType<size_t, tint::traits::LastParameterType<F>>;
- if constexpr (kHasIndexParam) {
- return f(cs..., index);
- } else {
- return f(cs...);
- }
+ return f(cs..., index);
}
auto* composite_el_ty = composite_ty->Elements(composite_ty).type;
@@ -486,8 +485,7 @@
Vector<const Value*, 8> els;
els.Reserve(n);
for (uint32_t i = 0; i < n; i++) {
- if (auto el =
- detail::TransformElements(mgr, composite_el_ty, f, index + i, cs->Index(i)...)) {
+ if (auto el = TransformElements(mgr, composite_el_ty, f, index + i, cs->Index(i)...)) {
els.Push(el.Get());
} else {
@@ -496,20 +494,6 @@
}
return mgr.Composite(composite_ty, std::move(els));
}
-} // namespace detail
-
-/// TransformElements constructs a new constant of type `composite_ty` by applying the
-/// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all
-/// input constants `cs` are of the same arity (all scalars or all vectors of the same size).
-/// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside
-/// the most deeply nested aggregate type will be passed in.
-template <typename F, typename... CONSTANTS>
-Eval::Result TransformElements(Manager& mgr,
- const core::type::Type* composite_ty,
- const F& f,
- CONSTANTS&&... cs) {
- return detail::TransformElements(mgr, composite_ty, f, 0, cs...);
-}
/// Signature of a unary transformation callback
using UnaryTransform = std::function<Eval::Result(const Value*)>;
@@ -607,6 +591,39 @@
}
return mgr.Composite(composite_ty, std::move(els));
}
+
+/// Signature of a ternary transformation callback
+using TernaryTransform = std::function<Eval::Result(const Value*, const Value*, const Value*)>;
+
+/// TransformTernaryElements constructs a new constant of type `composite_ty` by applying the
+/// transformation function 'f' on each of the most deeply nested elements of both `c0`, `c1`, and
+/// `c2`.
+Eval::Result TransformTernaryElements(Manager& mgr,
+ const core::type::Type* composite_ty,
+ const TernaryTransform& f,
+ const Value* c0,
+ const Value* c1,
+ const Value* c2) {
+ auto [el_ty, n] = c0->Type()->Elements();
+ if (!el_ty) {
+ return f(c0, c1, c2);
+ }
+
+ auto* composite_el_ty = composite_ty->Elements(composite_ty).type;
+
+ Vector<const Value*, 8> els;
+ els.Reserve(n);
+ for (uint32_t i = 0; i < n; i++) {
+ if (auto el = TransformTernaryElements(mgr, composite_el_ty, f, c0->Index(i), c1->Index(i),
+ c2->Index(i))) {
+ els.Push(el.Get());
+
+ } else {
+ return el.Failure();
+ }
+ }
+ return mgr.Composite(composite_ty, std::move(els));
+}
} // namespace
Eval::Eval(Manager& manager, diag::List& diagnostics, bool use_runtime_semantics /* = false */)
@@ -2307,7 +2324,7 @@
auto transform = [&](const Value* c0, const Value* c1, const Value* c2) {
return Dispatch_fia_fiu32_f16(ClampFunc(source, c0->Type()), c0, c1, c2);
};
- return TransformElements(mgr, ty, transform, args[0], args[1], args[2]);
+ return TransformTernaryElements(mgr, ty, transform, args[0], args[1], args[2]);
}
Eval::Result Eval::cos(const core::type::Type* ty,
@@ -2752,7 +2769,7 @@
};
return Dispatch_fa_f32_f16(create, c1, c2, c3);
};
- return TransformElements(mgr, ty, transform, args[0], args[1], args[2]);
+ return TransformTernaryElements(mgr, ty, transform, args[0], args[1], args[2]);
}
Eval::Result Eval::fract(const core::type::Type* ty,
@@ -2976,7 +2993,7 @@
return Dispatch_fa_f32_f16(create, c1);
};
- return TransformElements(mgr, ty, transform, args[0]);
+ return TransformElements(mgr, ty, transform, 0, args[0]);
}
Eval::Result Eval::length(const core::type::Type* ty,
@@ -3091,7 +3108,7 @@
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
- auto r = TransformElements(mgr, ty, transform, args[0], args[1]);
+ auto r = TransformElements(mgr, ty, transform, 0, args[0], args[1]);
if (!r) {
AddNote("when calculating mix", source);
}
@@ -3547,7 +3564,7 @@
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
- return TransformElements(mgr, ty, transform, args[0], args[1]);
+ return TransformElements(mgr, ty, transform, 0, args[0], args[1]);
}
Eval::Result Eval::sign(const core::type::Type* ty,
@@ -3645,7 +3662,7 @@
};
return Dispatch_fa_f32_f16(create, c0, c1, c2);
};
- return TransformElements(mgr, ty, transform, args[0], args[1], args[2]);
+ return TransformTernaryElements(mgr, ty, transform, args[0], args[1], args[2]);
}
Eval::Result Eval::step(const core::type::Type* ty,