tint/number: add Checked* overloads for f32 and f16

Also add missing unit tests for CheckedMul of floats.

Bug: tint:1581
Bug: tint:1747
Change-Id: I5d0d5d2b010803d6fd65f6feddc619cf1d071fe2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110170
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/number.h b/src/tint/number.h
index c116ddf..4fa6ed7 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -431,13 +431,14 @@
     return AInt(result);
 }
 
-/// @returns a + b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedAdd(AFloat a, AFloat b) {
-    auto result = a.value + b.value;
-    if (!std::isfinite(result)) {
+/// @returns a + b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedAdd(FloatingPointT a, FloatingPointT b) {
+    auto result = FloatingPointT{a.value + b.value};
+    if (!std::isfinite(result.value)) {
         return {};
     }
-    return AFloat{result};
+    return result;
 }
 
 /// @returns a - b, or an empty optional if the resulting value overflowed the AInt
@@ -462,13 +463,14 @@
     return AInt(result);
 }
 
-/// @returns a + b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedSub(AFloat a, AFloat b) {
-    auto result = a.value - b.value;
-    if (!std::isfinite(result)) {
+/// @returns a + b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedSub(FloatingPointT a, FloatingPointT b) {
+    auto result = FloatingPointT{a.value - b.value};
+    if (!std::isfinite(result.value)) {
         return {};
     }
-    return AFloat{result};
+    return result;
 }
 
 /// @returns a * b, or an empty optional if the resulting value overflowed the AInt
@@ -505,13 +507,14 @@
     return AInt(result);
 }
 
-/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
-    auto result = a.value * b.value;
-    if (!std::isfinite(result)) {
+/// @returns a * b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedMul(FloatingPointT a, FloatingPointT b) {
+    auto result = FloatingPointT{a.value * b.value};
+    if (!std::isfinite(result.value)) {
         return {};
     }
-    return AFloat{result};
+    return result;
 }
 
 /// @returns a / b, or an empty optional if the resulting value overflowed the AInt
@@ -527,13 +530,14 @@
     return AInt{a.value / b.value};
 }
 
-/// @returns a / b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedDiv(AFloat a, AFloat b) {
-    auto result = a.value / b.value;
-    if (!std::isfinite(result)) {
+/// @returns a / b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedDiv(FloatingPointT a, FloatingPointT b) {
+    auto result = FloatingPointT{a.value / b.value};
+    if (!std::isfinite(result.value)) {
         return {};
     }
-    return AFloat{result};
+    return result;
 }
 
 /// @returns a * b + c, or an empty optional if the value overflowed the AInt
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index e6ebfe6..f21e531 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -14,6 +14,7 @@
 
 #include <cmath>
 #include <tuple>
+#include <variant>
 #include <vector>
 
 #include "src/tint/program_builder.h"
@@ -26,6 +27,15 @@
 namespace tint {
 namespace {
 
+// Concats any number of std::vectors
+template <typename Vec, typename... Vecs>
+[[nodiscard]] inline auto Concat(Vec&& v1, Vecs&&... vs) {
+    auto total_size = v1.size() + (vs.size() + ...);
+    v1.reserve(total_size);
+    (std::move(vs.begin(), vs.end(), std::back_inserter(v1)), ...);
+    return std::move(v1);
+}
+
 // Next ULP up from kHighestF32 for a float64.
 constexpr double kHighestF32NextULP = 0x1.fffffe0000001p+127;
 
@@ -378,16 +388,23 @@
 #define OVERFLOW \
     {}
 
-using BinaryCheckedCase_AInt = std::tuple<std::optional<AInt>, AInt, AInt>;
-using BinaryCheckedCase_AFloat = std::tuple<std::optional<AFloat>, AFloat, AFloat>;
+template <typename T>
+auto Overflow = std::optional<T>{};
 
+using BinaryCheckedCase_AInt = std::tuple<std::optional<AInt>, AInt, AInt>;
 using CheckedAddTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
+
+using FloatInputTypes = std::variant<AFloat, f32, f16>;
+using FloatExpectedTypes =
+    std::variant<std::optional<AFloat>, std::optional<f32>, std::optional<f16>>;
+using BinaryCheckedCase_Float = std::tuple<FloatExpectedTypes, FloatInputTypes, FloatInputTypes>;
+
 TEST_P(CheckedAddTest_AInt, Test) {
     auto expect = std::get<0>(GetParam());
     auto a = std::get<1>(GetParam());
     auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " + 0x" << b;
-    EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << a << " + 0x" << b;
+    EXPECT_TRUE(CheckedAdd(a, b) == expect) << std::hex << "0x" << a << " + 0x" << b;
+    EXPECT_TRUE(CheckedAdd(b, a) == expect) << std::hex << "0x" << a << " + 0x" << b;
 }
 INSTANTIATE_TEST_SUITE_P(
     CheckedAddTest_AInt,
@@ -417,41 +434,50 @@
         ////////////////////////////////////////////////////////////////////////
     }));
 
-using CheckedAddTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
-TEST_P(CheckedAddTest_AFloat, Test) {
-    auto expect = std::get<0>(GetParam());
-    auto a = std::get<1>(GetParam());
-    auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " + 0x" << b;
-    EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << a << " + 0x" << b;
+using CheckedAddTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedAddTest_Float, Test) {
+    auto& p = GetParam();
+    std::visit(
+        [&](auto&& lhs) {
+            using T = std::decay_t<decltype(lhs)>;
+            auto rhs = std::get<T>(std::get<2>(p));
+            auto expect = std::get<std::optional<T>>(std::get<0>(p));
+            EXPECT_TRUE(CheckedAdd(lhs, rhs) == expect)
+                << std::hex << "0x" << lhs << " + 0x" << rhs;
+            EXPECT_TRUE(CheckedAdd(rhs, lhs) == expect)
+                << std::hex << "0x" << lhs << " + 0x" << rhs;
+        },
+        std::get<1>(p));
 }
-INSTANTIATE_TEST_SUITE_P(
-    CheckedAddTest_AFloat,
-    CheckedAddTest_AFloat,
-    testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
-        {AFloat(0), AFloat(0), AFloat(0)},
-        {AFloat(1), AFloat(1), AFloat(0)},
-        {AFloat(2), AFloat(1), AFloat(1)},
-        {AFloat(0), AFloat(-1), AFloat(1)},
-        {AFloat(3), AFloat(2), AFloat(1)},
-        {AFloat(-1), AFloat(-2), AFloat(1)},
-        {AFloat(0x300), AFloat(0x100), AFloat(0x200)},
-        {AFloat(0x100), AFloat(-0x100), AFloat(0x200)},
-        {AFloat::Highest(), AFloat(1), AFloat(AFloat::kHighestValue - 1)},
-        {AFloat::Lowest(), AFloat(-1), AFloat(AFloat::kLowestValue + 1)},
-        {AFloat::Highest(), AFloat::Highest(), AFloat(0)},
-        {AFloat::Lowest(), AFloat::Lowest(), AFloat(0)},
-        {OVERFLOW, AFloat::Highest(), AFloat::Highest()},
-        {OVERFLOW, AFloat::Lowest(), AFloat::Lowest()},
-        ////////////////////////////////////////////////////////////////////////
-    }));
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedAddTest_FloatCases() {
+    return {
+        {T(0), T(0), T(0)},
+        {T(1), T(1), T(0)},
+        {T(2), T(1), T(1)},
+        {T(0), T(-1), T(1)},
+        {T(3), T(2), T(1)},
+        {T(-1), T(-2), T(1)},
+        {T(0x300), T(0x100), T(0x200)},
+        {T(0x100), T(-0x100), T(0x200)},
+        {T::Highest(), T::Highest(), T(0)},
+        {T::Lowest(), T::Lowest(), T(0)},
+        {Overflow<T>, T::Highest(), T::Highest()},
+        {Overflow<T>, T::Lowest(), T::Lowest()},
+    };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedAddTest_Float,
+                         CheckedAddTest_Float,
+                         testing::ValuesIn(Concat(CheckedAddTest_FloatCases<AFloat>(),
+                                                  CheckedAddTest_FloatCases<f32>(),
+                                                  CheckedAddTest_FloatCases<f16>())));
 
 using CheckedSubTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
 TEST_P(CheckedSubTest_AInt, Test) {
     auto expect = std::get<0>(GetParam());
     auto a = std::get<1>(GetParam());
     auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedSub(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+    EXPECT_TRUE(CheckedSub(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
 }
 INSTANTIATE_TEST_SUITE_P(
     CheckedSubTest_AInt,
@@ -480,40 +506,48 @@
         ////////////////////////////////////////////////////////////////////////
     }));
 
-using CheckedSubTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
-TEST_P(CheckedSubTest_AFloat, Test) {
-    auto expect = std::get<0>(GetParam());
-    auto a = std::get<1>(GetParam());
-    auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedSub(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+using CheckedSubTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedSubTest_Float, Test) {
+    auto& p = GetParam();
+    std::visit(
+        [&](auto&& lhs) {
+            using T = std::decay_t<decltype(lhs)>;
+            auto rhs = std::get<T>(std::get<2>(p));
+            auto expect = std::get<std::optional<T>>(std::get<0>(p));
+            EXPECT_TRUE(CheckedSub(lhs, rhs) == expect)
+                << std::hex << "0x" << lhs << " - 0x" << rhs;
+        },
+        std::get<1>(p));
 }
-INSTANTIATE_TEST_SUITE_P(
-    CheckedSubTest_AFloat,
-    CheckedSubTest_AFloat,
-    testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
-        {AFloat(0), AFloat(0), AFloat(0)},
-        {AFloat(1), AFloat(1), AFloat(0)},
-        {AFloat(0), AFloat(1), AFloat(1)},
-        {AFloat(-2), AFloat(-1), AFloat(1)},
-        {AFloat(1), AFloat(2), AFloat(1)},
-        {AFloat(-3), AFloat(-2), AFloat(1)},
-        {AFloat(0x100), AFloat(0x300), AFloat(0x200)},
-        {AFloat(-0x300), AFloat(-0x100), AFloat(0x200)},
-        {AFloat::Highest(), AFloat(AFloat::kHighestValue - 1), AFloat(-1)},
-        {AFloat::Lowest(), AFloat(AFloat::kLowestValue + 1), AFloat(1)},
-        {AFloat::Highest(), AFloat::Highest(), AFloat(0)},
-        {AFloat::Lowest(), AFloat::Lowest(), AFloat(0)},
-        {OVERFLOW, AFloat::Lowest(), AFloat::Highest()},
-        ////////////////////////////////////////////////////////////////////////
-    }));
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedSubTest_FloatCases() {
+    return {
+        {T(0), T(0), T(0)},
+        {T(1), T(1), T(0)},
+        {T(0), T(1), T(1)},
+        {T(-2), T(-1), T(1)},
+        {T(1), T(2), T(1)},
+        {T(-3), T(-2), T(1)},
+        {T(0x100), T(0x300), T(0x200)},
+        {T(-0x300), T(-0x100), T(0x200)},
+        {T::Highest(), T::Highest(), T(0)},
+        {T::Lowest(), T::Lowest(), T(0)},
+        {Overflow<T>, T::Lowest(), T::Highest()},
+    };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedSubTest_Float,
+                         CheckedSubTest_Float,
+                         testing::ValuesIn(Concat(CheckedSubTest_FloatCases<AFloat>(),
+                                                  CheckedSubTest_FloatCases<f32>(),
+                                                  CheckedSubTest_FloatCases<f16>())));
 
 using CheckedMulTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
 TEST_P(CheckedMulTest_AInt, Test) {
     auto expect = std::get<0>(GetParam());
     auto a = std::get<1>(GetParam());
     auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedMul(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
-    EXPECT_EQ(CheckedMul(b, a), expect) << std::hex << "0x" << a << " * 0x" << b;
+    EXPECT_TRUE(CheckedMul(a, b) == expect) << std::hex << "0x" << a << " * 0x" << b;
+    EXPECT_TRUE(CheckedMul(b, a) == expect) << std::hex << "0x" << a << " * 0x" << b;
 }
 INSTANTIATE_TEST_SUITE_P(
     CheckedMulTest_AInt,
@@ -553,12 +587,48 @@
         ////////////////////////////////////////////////////////////////////////
     }));
 
+using CheckedMulTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedMulTest_Float, Test) {
+    auto& p = GetParam();
+    std::visit(
+        [&](auto&& lhs) {
+            using T = std::decay_t<decltype(lhs)>;
+            auto rhs = std::get<T>(std::get<2>(p));
+            auto expect = std::get<std::optional<T>>(std::get<0>(p));
+            EXPECT_TRUE(CheckedMul(lhs, rhs) == expect)
+                << std::hex << "0x" << lhs << " * 0x" << rhs;
+            EXPECT_TRUE(CheckedMul(rhs, lhs) == expect)
+                << std::hex << "0x" << lhs << " * 0x" << rhs;
+        },
+        std::get<1>(p));
+}
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedMulTest_FloatCases() {
+    return {
+        {T(0), T(0), T(0)},
+        {T(0), T(1), T(0)},
+        {T(1), T(1), T(1)},
+        {T(-1), T(-1), T(1)},
+        {T(2), T(2), T(1)},
+        {T(-2), T(-2), T(1)},
+        {T(0), T::Highest(), T(0)},
+        {T(0), T::Lowest(), -T(0)},
+        {Overflow<T>, T::Highest(), T::Highest()},
+        {Overflow<T>, T::Lowest(), T::Lowest()},
+    };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedMulTest_Float,
+                         CheckedMulTest_Float,
+                         testing::ValuesIn(Concat(CheckedMulTest_FloatCases<AFloat>(),
+                                                  CheckedMulTest_FloatCases<f32>(),
+                                                  CheckedMulTest_FloatCases<f16>())));
+
 using CheckedDivTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
 TEST_P(CheckedDivTest_AInt, Test) {
     auto expect = std::get<0>(GetParam());
     auto a = std::get<1>(GetParam());
     auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedDiv(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+    EXPECT_TRUE(CheckedDiv(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
 }
 INSTANTIATE_TEST_SUITE_P(
     CheckedDivTest_AInt,
@@ -579,33 +649,43 @@
         ////////////////////////////////////////////////////////////////////////
     }));
 
-using CheckedDivTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
-TEST_P(CheckedDivTest_AFloat, Test) {
-    auto expect = std::get<0>(GetParam());
-    auto a = std::get<1>(GetParam());
-    auto b = std::get<2>(GetParam());
-    EXPECT_EQ(CheckedDiv(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+using CheckedDivTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedDivTest_Float, Test) {
+    auto& p = GetParam();
+    std::visit(
+        [&](auto&& lhs) {
+            using T = std::decay_t<decltype(lhs)>;
+            auto rhs = std::get<T>(std::get<2>(p));
+            auto expect = std::get<std::optional<T>>(std::get<0>(p));
+            EXPECT_TRUE(CheckedDiv(lhs, rhs) == expect)
+                << std::hex << "0x" << lhs << " / 0x" << rhs;
+        },
+        std::get<1>(p));
 }
-INSTANTIATE_TEST_SUITE_P(
-    CheckedDivTest_AFloat,
-    CheckedDivTest_AFloat,
-    testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
-        {AFloat(0), AFloat(0), AFloat(1)},
-        {AFloat(1), AFloat(1), AFloat(1)},
-        {AFloat(1), AFloat(1), AFloat(1)},
-        {AFloat(2), AFloat(2), AFloat(1)},
-        {AFloat(2), AFloat(4), AFloat(2)},
-        {AFloat::Highest(), AFloat::Highest(), AFloat(1)},
-        {AFloat::Lowest(), AFloat::Lowest(), AFloat(1)},
-        {AFloat(1), AFloat::Highest(), AFloat::Highest()},
-        {AFloat(0), AFloat(0), AFloat::Highest()},
-        {-AFloat(0), AFloat(0), AFloat::Lowest()},
-        {OVERFLOW, AFloat(123), AFloat(0)},
-        {OVERFLOW, AFloat(123), AFloat(-0)},
-        {OVERFLOW, AFloat(-123), AFloat(0)},
-        {OVERFLOW, AFloat(-123), AFloat(-0)},
-        ////////////////////////////////////////////////////////////////////////
-    }));
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedDivTest_FloatCases() {
+    return {
+        {T(0), T(0), T(1)},
+        {T(1), T(1), T(1)},
+        {T(1), T(1), T(1)},
+        {T(2), T(2), T(1)},
+        {T(2), T(4), T(2)},
+        {T::Highest(), T::Highest(), T(1)},
+        {T::Lowest(), T::Lowest(), T(1)},
+        {T(1), T::Highest(), T::Highest()},
+        {T(0), T(0), T::Highest()},
+        {-T(0), T(0), T::Lowest()},
+        {Overflow<T>, T(123), T(0)},
+        {Overflow<T>, T(123), T(-0)},
+        {Overflow<T>, T(-123), T(0)},
+        {Overflow<T>, T(-123), T(-0)},
+    };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedDivTest_Float,
+                         CheckedDivTest_Float,
+                         testing::ValuesIn(Concat(CheckedDivTest_FloatCases<AFloat>(),
+                                                  CheckedDivTest_FloatCases<f32>(),
+                                                  CheckedDivTest_FloatCases<f16>())));
 
 using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;