[eval] Make a TransformTernaryElements

Make a specific non-templated version of the 3 element Transform method.
The generic `TransformElement` is restricted to only trigger if the
special `size` parameter is requested on the function template.

Reduces template instantiations in order to reduce binary size.

Change-Id: I922e75ee9208fe7e5e6ed54983a452ec92554195
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/148281
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 03a8eee..1aeadb1 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -462,23 +462,22 @@
     return value_stack.Pop();
 }
 
-namespace detail {
-/// Implementation of TransformElements
+/// TransformElements constructs a new constant of type `composite_ty` by applying the
+/// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all
+/// input constants `cs` are of the same arity (all scalars or all vectors of the same size).
+/// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside
+/// the most deeply nested aggregate type will be passed in.
 template <typename F, typename... CONSTANTS>
-Eval::Result TransformElements(Manager& mgr,
-                               const core::type::Type* composite_ty,
-                               const F& f,
-                               size_t index,
-                               CONSTANTS&&... cs) {
+tint::traits::EnableIf<tint::traits::IsType<size_t, tint::traits::LastParameterType<F>>,
+                       Eval::Result>
+TransformElements(Manager& mgr,
+                  const core::type::Type* composite_ty,
+                  const F& f,
+                  size_t index,
+                  CONSTANTS&&... cs) {
     auto [el_ty, n] = First(cs...)->Type()->Elements();
     if (!el_ty) {
-        constexpr bool kHasIndexParam =
-            tint::traits::IsType<size_t, tint::traits::LastParameterType<F>>;
-        if constexpr (kHasIndexParam) {
-            return f(cs..., index);
-        } else {
-            return f(cs...);
-        }
+        return f(cs..., index);
     }
 
     auto* composite_el_ty = composite_ty->Elements(composite_ty).type;
@@ -486,8 +485,7 @@
     Vector<const Value*, 8> els;
     els.Reserve(n);
     for (uint32_t i = 0; i < n; i++) {
-        if (auto el =
-                detail::TransformElements(mgr, composite_el_ty, f, index + i, cs->Index(i)...)) {
+        if (auto el = TransformElements(mgr, composite_el_ty, f, index + i, cs->Index(i)...)) {
             els.Push(el.Get());
 
         } else {
@@ -496,20 +494,6 @@
     }
     return mgr.Composite(composite_ty, std::move(els));
 }
-}  // namespace detail
-
-/// TransformElements constructs a new constant of type `composite_ty` by applying the
-/// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all
-/// input constants `cs` are of the same arity (all scalars or all vectors of the same size).
-/// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside
-/// the most deeply nested aggregate type will be passed in.
-template <typename F, typename... CONSTANTS>
-Eval::Result TransformElements(Manager& mgr,
-                               const core::type::Type* composite_ty,
-                               const F& f,
-                               CONSTANTS&&... cs) {
-    return detail::TransformElements(mgr, composite_ty, f, 0, cs...);
-}
 
 /// Signature of a unary transformation callback
 using UnaryTransform = std::function<Eval::Result(const Value*)>;
@@ -607,6 +591,39 @@
     }
     return mgr.Composite(composite_ty, std::move(els));
 }
+
+/// Signature of a ternary transformation callback
+using TernaryTransform = std::function<Eval::Result(const Value*, const Value*, const Value*)>;
+
+/// TransformTernaryElements constructs a new constant of type `composite_ty` by applying the
+/// transformation function 'f' on each of the most deeply nested elements of both `c0`, `c1`, and
+/// `c2`.
+Eval::Result TransformTernaryElements(Manager& mgr,
+                                      const core::type::Type* composite_ty,
+                                      const TernaryTransform& f,
+                                      const Value* c0,
+                                      const Value* c1,
+                                      const Value* c2) {
+    auto [el_ty, n] = c0->Type()->Elements();
+    if (!el_ty) {
+        return f(c0, c1, c2);
+    }
+
+    auto* composite_el_ty = composite_ty->Elements(composite_ty).type;
+
+    Vector<const Value*, 8> els;
+    els.Reserve(n);
+    for (uint32_t i = 0; i < n; i++) {
+        if (auto el = TransformTernaryElements(mgr, composite_el_ty, f, c0->Index(i), c1->Index(i),
+                                               c2->Index(i))) {
+            els.Push(el.Get());
+
+        } else {
+            return el.Failure();
+        }
+    }
+    return mgr.Composite(composite_ty, std::move(els));
+}
 }  // namespace
 
 Eval::Eval(Manager& manager, diag::List& diagnostics, bool use_runtime_semantics /* = false */)
@@ -2307,7 +2324,7 @@
     auto transform = [&](const Value* c0, const Value* c1, const Value* c2) {
         return Dispatch_fia_fiu32_f16(ClampFunc(source, c0->Type()), c0, c1, c2);
     };
-    return TransformElements(mgr, ty, transform, args[0], args[1], args[2]);
+    return TransformTernaryElements(mgr, ty, transform, args[0], args[1], args[2]);
 }
 
 Eval::Result Eval::cos(const core::type::Type* ty,
@@ -2752,7 +2769,7 @@
         };
         return Dispatch_fa_f32_f16(create, c1, c2, c3);
     };
-    return TransformElements(mgr, ty, transform, args[0], args[1], args[2]);
+    return TransformTernaryElements(mgr, ty, transform, args[0], args[1], args[2]);
 }
 
 Eval::Result Eval::fract(const core::type::Type* ty,
@@ -2976,7 +2993,7 @@
         return Dispatch_fa_f32_f16(create, c1);
     };
 
-    return TransformElements(mgr, ty, transform, args[0]);
+    return TransformElements(mgr, ty, transform, 0, args[0]);
 }
 
 Eval::Result Eval::length(const core::type::Type* ty,
@@ -3091,7 +3108,7 @@
         };
         return Dispatch_fa_f32_f16(create, c0, c1);
     };
-    auto r = TransformElements(mgr, ty, transform, args[0], args[1]);
+    auto r = TransformElements(mgr, ty, transform, 0, args[0], args[1]);
     if (!r) {
         AddNote("when calculating mix", source);
     }
@@ -3547,7 +3564,7 @@
         return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
     };
 
-    return TransformElements(mgr, ty, transform, args[0], args[1]);
+    return TransformElements(mgr, ty, transform, 0, args[0], args[1]);
 }
 
 Eval::Result Eval::sign(const core::type::Type* ty,
@@ -3645,7 +3662,7 @@
         };
         return Dispatch_fa_f32_f16(create, c0, c1, c2);
     };
-    return TransformElements(mgr, ty, transform, args[0], args[1], args[2]);
+    return TransformTernaryElements(mgr, ty, transform, args[0], args[1], args[2]);
 }
 
 Eval::Result Eval::step(const core::type::Type* ty,