traits: Replace FirstParamType with ParamType

Allows you to infer the N'th parameter type of a function.

Change-Id: Iab7065cb37dbf1332cef601bca91894b8c6b4edf
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35662
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/clone_context.h b/src/ast/clone_context.h
index 05f11d6..a9e303b 100644
--- a/src/ast/clone_context.h
+++ b/src/ast/clone_context.h
@@ -119,7 +119,7 @@
   ///        `T* (T*)`, where `T` derives from CastableBase
   template <typename F>
   void ReplaceAll(F replacer) {
-    using TPtr = traits::FirstParamTypeT<F>;
+    using TPtr = traits::ParamTypeT<F, 0>;
     using T = typename std::remove_pointer<TPtr>::type;
     transforms_.emplace_back([=](CastableBase* in) {
       auto* in_as_t = in->As<T>();
diff --git a/src/ast/traits.h b/src/ast/traits.h
index bb1f999..7e27036 100644
--- a/src/ast/traits.h
+++ b/src/ast/traits.h
@@ -15,45 +15,55 @@
 #ifndef SRC_AST_TRAITS_H_
 #define SRC_AST_TRAITS_H_
 
+#include <tuple>
 #include <type_traits>
 
 namespace tint {
 namespace ast {
 namespace traits {
 
-/// FirstParamType is a traits helper that infers the type of the first
-/// parameter of the function, method, static method, lambda, or function-like
-/// object `F`.
-template <typename F>
-struct FirstParamType {
-  /// The type of the first parameter of the function-like object `F`
-  using type = typename FirstParamType<decltype(&F::operator())>::type;
+/// NthTypeOf returns the `N`th type in `Types`
+template <int N, typename... Types>
+using NthTypeOf = typename std::tuple_element<N, std::tuple<Types...>>::type;
+
+/// ParamType is a traits helper that infers the type of the `N`th parameter
+/// of the function, method, static method, lambda, or function-like object `F`.
+template <typename F, int N>
+struct ParamType {
+  /// The type of the `N`th parameter of the function-like object `F`
+  using type = typename ParamType<decltype(&F::operator()), N>::type;
 };
 
-/// FirstParamType specialization for a regular function or static method.
-template <typename R, typename Arg>
-struct FirstParamType<R (*)(Arg)> {
-  /// The type of the first parameter of the function
+/// ParamType specialization for a regular function or static method.
+template <typename R, int N, typename... Args>
+struct ParamType<R (*)(Args...), N> {
+  /// Arg is the raw type of the `N`th parameter of the function
+  using Arg = NthTypeOf<N, Args...>;
+  /// The type of the `N`th parameter of the function
   using type = typename std::decay<Arg>::type;
 };
 
-/// FirstParamType specialization for a non-static method.
-template <typename R, typename C, typename Arg>
-struct FirstParamType<R (C::*)(Arg)> {
-  /// The type of the first parameter of the function
+/// ParamType specialization for a non-static method.
+template <typename R, typename C, int N, typename... Args>
+struct ParamType<R (C::*)(Args...), N> {
+  /// Arg is the raw type of the `N`th parameter of the function
+  using Arg = NthTypeOf<N, Args...>;
+  /// The type of the `N`th parameter of the function
   using type = typename std::decay<Arg>::type;
 };
 
-/// FirstParamType specialization for a non-static, const method.
-template <typename R, typename C, typename Arg>
-struct FirstParamType<R (C::*)(Arg) const> {
-  /// The type of the first parameter of the function
+/// ParamType specialization for a non-static, const method.
+template <typename R, typename C, int N, typename... Args>
+struct ParamType<R (C::*)(Args...) const, N> {
+  /// Arg is the raw type of the `N`th parameter of the function
+  using Arg = NthTypeOf<N, Args...>;
+  /// The type of the `N`th parameter of the function
   using type = typename std::decay<Arg>::type;
 };
 
-/// FirstParamTypeT is an alias to `typename FirstParamType<F>::type`.
-template <typename F>
-using FirstParamTypeT = typename FirstParamType<F>::type;
+/// ParamTypeT is an alias to `typename ParamType<F, N>::type`.
+template <typename F, int N>
+using ParamTypeT = typename ParamType<F, N>::type;
 
 /// If T is a base of BASE then EnableIfIsType resolves to type T, otherwise an
 /// invalid type.
diff --git a/src/ast/traits_test.cc b/src/ast/traits_test.cc
index 05ccced..6621827 100644
--- a/src/ast/traits_test.cc
+++ b/src/ast/traits_test.cc
@@ -24,49 +24,80 @@
 
 namespace {
 struct S {};
-void F(S) {}
+void F1(S) {}
+void F3(int, S, float) {}
 }  // namespace
 
-TEST(FirstParamType, Function) {
-  F({});  // Avoid unused method warning
-  static_assert(std::is_same<FirstParamTypeT<decltype(&F)>, S>::value, "");
+TEST(ParamType, Function) {
+  F1({});        // Avoid unused method warning
+  F3(0, {}, 0);  // Avoid unused method warning
+  static_assert(std::is_same<ParamTypeT<decltype(&F1), 0>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&F3), 0>, int>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&F3), 1>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&F3), 2>, float>::value, "");
 }
 
-TEST(FirstParamType, Method) {
+TEST(ParamType, Method) {
   class C {
    public:
-    void f(S) {}
+    void F1(S) {}
+    void F3(int, S, float) {}
   };
-  C().f({});  // Avoid unused method warning
-  static_assert(std::is_same<FirstParamTypeT<decltype(&C::f)>, S>::value, "");
-}
-
-TEST(FirstParamType, ConstMethod) {
-  class C {
-   public:
-    void f(S) const {}
-  };
-  C().f({});  // Avoid unused method warning
-  static_assert(std::is_same<FirstParamTypeT<decltype(&C::f)>, S>::value, "");
-}
-
-TEST(FirstParamType, StaticMethod) {
-  class C {
-   public:
-    static void f(S) {}
-  };
-  C().f({});  // Avoid unused method warning
-  static_assert(std::is_same<FirstParamTypeT<decltype(&C::f)>, S>::value, "");
-}
-
-TEST(FirstParamType, FunctionLike) {
-  static_assert(std::is_same<FirstParamTypeT<std::function<void(S)>>, S>::value,
+  C().F1({});        // Avoid unused method warning
+  C().F3(0, {}, 0);  // Avoid unused method warning
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F1), 0>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 0>, int>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 1>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 2>, float>::value,
                 "");
 }
 
-TEST(FirstParamType, Lambda) {
-  auto l = [](S) {};
-  static_assert(std::is_same<FirstParamTypeT<decltype(l)>, S>::value, "");
+TEST(ParamType, ConstMethod) {
+  class C {
+   public:
+    void F1(S) const {}
+    void F3(int, S, float) const {}
+  };
+  C().F1({});        // Avoid unused method warning
+  C().F3(0, {}, 0);  // Avoid unused method warning
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F1), 0>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 0>, int>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 1>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 2>, float>::value,
+                "");
+}
+
+TEST(ParamType, StaticMethod) {
+  class C {
+   public:
+    static void F1(S) {}
+    static void F3(int, S, float) {}
+  };
+  C::F1({});        // Avoid unused method warning
+  C::F3(0, {}, 0);  // Avoid unused method warning
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F1), 0>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 0>, int>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 1>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(&C::F3), 2>, float>::value,
+                "");
+}
+
+TEST(ParamType, FunctionLike) {
+  using F1 = std::function<void(S)>;
+  using F3 = std::function<void(int, S, float)>;
+  static_assert(std::is_same<ParamTypeT<F1, 0>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<F3, 0>, int>::value, "");
+  static_assert(std::is_same<ParamTypeT<F3, 1>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<F3, 2>, float>::value, "");
+}
+
+TEST(ParamType, Lambda) {
+  auto l1 = [](S) {};
+  auto l3 = [](int, S, float) {};
+  static_assert(std::is_same<ParamTypeT<decltype(l1), 0>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(l3), 0>, int>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(l3), 1>, S>::value, "");
+  static_assert(std::is_same<ParamTypeT<decltype(l3), 2>, float>::value, "");
 }
 
 }  // namespace traits