castable: Make Switch() smarter about return types

Infer the return type by finding the common type across all cases.
Types that derive from CastableBase will automatically infer to
the common base class.

Change-Id: I2112ca1abae34e55396685e9ebf2da12f8a6e3fc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/80320
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/castable.h b/src/tint/castable.h
index 51cad00..05144f7 100644
--- a/src/tint/castable.h
+++ b/src/tint/castable.h
@@ -607,7 +607,7 @@
     if (type->Is(&TypeInfo::Of<CaseType>())) {
       auto* ptr = static_cast<CaseType*>(object);
       if constexpr (kHasReturnType) {
-        *result = std::get<0>(cases)(ptr);
+        *result = static_cast<RETURN_TYPE>(std::get<0>(cases)(ptr));
       } else {
         std::get<0>(cases)(ptr);
       }
@@ -654,7 +654,8 @@
                                       traits::Slice<0, kDefaultIndex>(cases))) {
         // Nothing matched. Evaluate default case.
         if constexpr (kHasReturnType) {
-          *result = std::get<kDefaultIndex>(cases)({});
+          *result =
+              static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({}));
         } else {
           std::get<kDefaultIndex>(cases)({});
         }
@@ -667,7 +668,7 @@
     if constexpr (kHasDefaultCase) {
       // Evaluate default case.
       if constexpr (kHasReturnType) {
-        *result = std::get<kDefaultIndex>(cases)({});
+        *result = static_cast<RETURN_TYPE>(std::get<kDefaultIndex>(cases)({}));
       } else {
         std::get<kDefaultIndex>(cases)({});
       }
@@ -675,6 +676,81 @@
   }
 }
 
+/// Resolves to T if T is not nullptr_t, otherwise resolves to Ignore.
+template <typename T>
+using NullptrToIgnore =
+    std::conditional_t<std::is_same_v<T, std::nullptr_t>, Ignore, T>;
+
+/// Resolves to `const TYPE` if any of `CASE_RETURN_TYPES` are const or
+/// pointer-to-const, otherwise resolves to TYPE.
+template <typename TYPE, typename... CASE_RETURN_TYPES>
+using PropagateReturnConst = std::conditional_t<
+    // Are any of the pointer-stripped types const?
+    (std::is_const_v<std::remove_pointer_t<CASE_RETURN_TYPES>> || ...),
+    const TYPE,  // Yes: Apply const to TYPE
+    TYPE>;       // No:  Passthrough
+
+/// SwitchReturnTypeImpl is the implementation of SwitchReturnType
+template <bool IS_CASTABLE,
+          typename REQUESTED_TYPE,
+          typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl;
+
+/// SwitchReturnTypeImpl specialization for non-castable case types and an
+/// explicitly specified return type.
+template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ false,
+                            REQUESTED_TYPE,
+                            CASE_RETURN_TYPES...> {
+  /// Resolves to `REQUESTED_TYPE`
+  using type = REQUESTED_TYPE;
+};
+
+/// SwitchReturnTypeImpl specialization for non-castable case types and an
+/// inferred return type.
+template <typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ false,
+                            Infer,
+                            CASE_RETURN_TYPES...> {
+  /// Resolves to the common type for all the cases return types.
+  using type = std::common_type_t<CASE_RETURN_TYPES...>;
+};
+
+/// SwitchReturnTypeImpl specialization for castable case types and an
+/// explicitly specified return type.
+template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ true,
+                            REQUESTED_TYPE,
+                            CASE_RETURN_TYPES...> {
+ public:
+  /// Resolves to `const REQUESTED_TYPE*` or `REQUESTED_TYPE*`
+  using type = PropagateReturnConst<std::remove_pointer_t<REQUESTED_TYPE>,
+                                    CASE_RETURN_TYPES...>*;
+};
+
+/// SwitchReturnTypeImpl specialization for castable case types and an infered
+/// return type.
+template <typename... CASE_RETURN_TYPES>
+struct SwitchReturnTypeImpl</*IS_CASTABLE*/ true, Infer, CASE_RETURN_TYPES...> {
+ private:
+  using InferredType = CastableCommonBase<
+      detail::NullptrToIgnore<std::remove_pointer_t<CASE_RETURN_TYPES>>...>;
+
+ public:
+  /// `const T*` or `T*`, where T is the common base type for all the castable
+  /// case types.
+  using type = PropagateReturnConst<InferredType, CASE_RETURN_TYPES...>*;
+};
+
+/// Resolves to the return type for a Switch() with the requested return type
+/// `REQUESTED_TYPE` and case statement return types. If `REQUESTED_TYPE` is
+/// Infer then the return type will be inferred from the case return types.
+template <typename REQUESTED_TYPE, typename... CASE_RETURN_TYPES>
+using SwitchReturnType = typename SwitchReturnTypeImpl<
+    IsCastable<NullptrToIgnore<std::remove_pointer_t<CASE_RETURN_TYPES>>...>,
+    REQUESTED_TYPE,
+    CASE_RETURN_TYPES...>::type;
+
 }  // namespace detail
 
 /// Switch is used to dispatch one of the provided callback case handler
@@ -712,10 +788,12 @@
 /// @param cases the switch cases
 /// @return the value returned by the called case. If no cases matched, then the
 /// zero value for the consistent case type.
-template <typename T, typename... CASES>
+template <typename RETURN_TYPE = detail::Infer,
+          typename T = CastableBase,
+          typename... CASES>
 inline auto Switch(T* object, CASES&&... cases) {
-  using Cases = std::tuple<CASES...>;
-  using ReturnType = traits::ReturnType<std::tuple_element_t<0, Cases>>;
+  using ReturnType =
+      detail::SwitchReturnType<RETURN_TYPE, traits::ReturnType<CASES>...>;
   static constexpr bool kHasReturnType = !std::is_same_v<ReturnType, void>;
 
   if constexpr (kHasReturnType) {
diff --git a/src/tint/castable_test.cc b/src/tint/castable_test.cc
index e5698f9..7ed66cb 100644
--- a/src/tint/castable_test.cc
+++ b/src/tint/castable_test.cc
@@ -380,6 +380,321 @@
   }
 }
 
+TEST(Castable, SwitchReturnValueWithDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    const char* result = Switch(
+        frog.get(),                              //
+        [](Mammal*) { return "mammal"; },        //
+        [](Amphibian*) { return "amphibian"; },  //
+        [](Default) { return "unknown"; });
+    static_assert(std::is_same_v<decltype(result), const char*>);
+    EXPECT_EQ(std::string(result), "amphibian");
+  }
+  {
+    const char* result = Switch(
+        bear.get(),                              //
+        [](Mammal*) { return "mammal"; },        //
+        [](Amphibian*) { return "amphibian"; },  //
+        [](Default) { return "unknown"; });
+    static_assert(std::is_same_v<decltype(result), const char*>);
+    EXPECT_EQ(std::string(result), "mammal");
+  }
+  {
+    const char* result = Switch(
+        gecko.get(),                             //
+        [](Mammal*) { return "mammal"; },        //
+        [](Amphibian*) { return "amphibian"; },  //
+        [](Default) { return "unknown"; });
+    static_assert(std::is_same_v<decltype(result), const char*>);
+    EXPECT_EQ(std::string(result), "unknown");
+  }
+}
+
+TEST(Castable, SwitchReturnValueWithoutDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    const char* result = Switch(
+        frog.get(),                        //
+        [](Mammal*) { return "mammal"; },  //
+        [](Amphibian*) { return "amphibian"; });
+    static_assert(std::is_same_v<decltype(result), const char*>);
+    EXPECT_EQ(std::string(result), "amphibian");
+  }
+  {
+    const char* result = Switch(
+        bear.get(),                        //
+        [](Mammal*) { return "mammal"; },  //
+        [](Amphibian*) { return "amphibian"; });
+    static_assert(std::is_same_v<decltype(result), const char*>);
+    EXPECT_EQ(std::string(result), "mammal");
+  }
+  {
+    auto* result = Switch(
+        gecko.get(),                       //
+        [](Mammal*) { return "mammal"; },  //
+        [](Amphibian*) { return "amphibian"; });
+    static_assert(std::is_same_v<decltype(result), const char*>);
+    EXPECT_EQ(result, nullptr);
+  }
+}
+
+TEST(Castable, SwitchInferPODReturnTypeWithDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto result = Switch(
+        frog.get(),                       //
+        [](Mammal*) { return 1; },        //
+        [](Amphibian*) { return 2.0f; },  //
+        [](Default) { return 3.0; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 2.0);
+  }
+  {
+    auto result = Switch(
+        bear.get(),                       //
+        [](Mammal*) { return 1.0; },      //
+        [](Amphibian*) { return 2.0f; },  //
+        [](Default) { return 3; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 1.0);
+  }
+  {
+    auto result = Switch(
+        gecko.get(),                   //
+        [](Mammal*) { return 1.0f; },  //
+        [](Amphibian*) { return 2; },  //
+        [](Default) { return 3.0; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 3.0);
+  }
+}
+
+TEST(Castable, SwitchInferPODReturnTypeWithoutDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto result = Switch(
+        frog.get(),                 //
+        [](Mammal*) { return 1; },  //
+        [](Amphibian*) { return 2.0f; });
+    static_assert(std::is_same_v<decltype(result), float>);
+    EXPECT_EQ(result, 2.0f);
+  }
+  {
+    auto result = Switch(
+        bear.get(),                    //
+        [](Mammal*) { return 1.0f; },  //
+        [](Amphibian*) { return 2; });
+    static_assert(std::is_same_v<decltype(result), float>);
+    EXPECT_EQ(result, 1.0f);
+  }
+  {
+    auto result = Switch(
+        gecko.get(),                  //
+        [](Mammal*) { return 1.0; },  //
+        [](Amphibian*) { return 2.0f; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 0.0);
+  }
+}
+
+TEST(Castable, SwitchInferCastableReturnTypeWithDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto* result = Switch(
+        frog.get(),                          //
+        [](Mammal* p) { return p; },         //
+        [](Amphibian*) { return nullptr; },  //
+        [](Default) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), Mammal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+  {
+    auto* result = Switch(
+        bear.get(),                   //
+        [](Mammal* p) { return p; },  //
+        [](Amphibian* p) { return const_cast<const Amphibian*>(p); },
+        [](Default) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), const Animal*>);
+    EXPECT_EQ(result, bear.get());
+  }
+  {
+    auto* result = Switch(
+        gecko.get(),                     //
+        [](Mammal* p) { return p; },     //
+        [](Amphibian* p) { return p; },  //
+        [](Default) -> CastableBase* { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), CastableBase*>);
+    EXPECT_EQ(result, nullptr);
+  }
+}
+
+TEST(Castable, SwitchInferCastableReturnTypeWithoutDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto* result = Switch(
+        frog.get(),                   //
+        [](Mammal* p) { return p; },  //
+        [](Amphibian*) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), Mammal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+  {
+    auto* result = Switch(
+        bear.get(),                                                     //
+        [](Mammal* p) { return p; },                                    //
+        [](Amphibian* p) { return const_cast<const Amphibian*>(p); });  //
+    static_assert(std::is_same_v<decltype(result), const Animal*>);
+    EXPECT_EQ(result, bear.get());
+  }
+  {
+    auto* result = Switch(
+        gecko.get(),                  //
+        [](Mammal* p) { return p; },  //
+        [](Amphibian* p) { return p; });
+    static_assert(std::is_same_v<decltype(result), Animal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+}
+
+TEST(Castable, SwitchExplicitPODReturnTypeWithDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto result = Switch<double>(
+        frog.get(),                       //
+        [](Mammal*) { return 1; },        //
+        [](Amphibian*) { return 2.0f; },  //
+        [](Default) { return 3.0; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 2.0f);
+  }
+  {
+    auto result = Switch<double>(
+        bear.get(),                    //
+        [](Mammal*) { return 1; },     //
+        [](Amphibian*) { return 2; },  //
+        [](Default) { return 3; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 1.0f);
+  }
+  {
+    auto result = Switch<double>(
+        gecko.get(),                      //
+        [](Mammal*) { return 1.0f; },     //
+        [](Amphibian*) { return 2.0f; },  //
+        [](Default) { return 3.0f; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 3.0f);
+  }
+}
+
+TEST(Castable, SwitchExplicitPODReturnTypeWithoutDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto result = Switch<double>(
+        frog.get(),                 //
+        [](Mammal*) { return 1; },  //
+        [](Amphibian*) { return 2.0f; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 2.0f);
+  }
+  {
+    auto result = Switch<double>(
+        bear.get(),                    //
+        [](Mammal*) { return 1.0f; },  //
+        [](Amphibian*) { return 2; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 1.0f);
+  }
+  {
+    auto result = Switch<double>(
+        gecko.get(),                  //
+        [](Mammal*) { return 1.0; },  //
+        [](Amphibian*) { return 2.0f; });
+    static_assert(std::is_same_v<decltype(result), double>);
+    EXPECT_EQ(result, 0.0);
+  }
+}
+
+TEST(Castable, SwitchExplicitCastableReturnTypeWithDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto* result = Switch<Animal>(
+        frog.get(),                          //
+        [](Mammal* p) { return p; },         //
+        [](Amphibian*) { return nullptr; },  //
+        [](Default) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), Animal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+  {
+    auto* result = Switch<CastableBase>(
+        bear.get(),                   //
+        [](Mammal* p) { return p; },  //
+        [](Amphibian* p) { return const_cast<const Amphibian*>(p); },
+        [](Default) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), const CastableBase*>);
+    EXPECT_EQ(result, bear.get());
+  }
+  {
+    auto* result = Switch<const Animal>(
+        gecko.get(),                     //
+        [](Mammal* p) { return p; },     //
+        [](Amphibian* p) { return p; },  //
+        [](Default) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), const Animal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+}
+
+TEST(Castable, SwitchExplicitCastableReturnTypeWithoutDefault) {
+  std::unique_ptr<Animal> frog = std::make_unique<Frog>();
+  std::unique_ptr<Animal> bear = std::make_unique<Bear>();
+  std::unique_ptr<Animal> gecko = std::make_unique<Gecko>();
+  {
+    auto* result = Switch<Animal>(
+        frog.get(),                   //
+        [](Mammal* p) { return p; },  //
+        [](Amphibian*) { return nullptr; });
+    static_assert(std::is_same_v<decltype(result), Animal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+  {
+    auto* result = Switch<CastableBase>(
+        bear.get(),                                                     //
+        [](Mammal* p) { return p; },                                    //
+        [](Amphibian* p) { return const_cast<const Amphibian*>(p); });  //
+    static_assert(std::is_same_v<decltype(result), const CastableBase*>);
+    EXPECT_EQ(result, bear.get());
+  }
+  {
+    auto* result = Switch<const Animal*>(
+        gecko.get(),                  //
+        [](Mammal* p) { return p; },  //
+        [](Amphibian* p) { return p; });
+    static_assert(std::is_same_v<decltype(result), const Animal*>);
+    EXPECT_EQ(result, nullptr);
+  }
+}
+
 TEST(Castable, SwitchNull) {
   Animal* null = nullptr;
   Switch(