tint/number: add Checked* overloads for f32 and f16
Also add missing unit tests for CheckedMul of floats.
Bug: tint:1581
Bug: tint:1747
Change-Id: I5d0d5d2b010803d6fd65f6feddc619cf1d071fe2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110170
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/number.h b/src/tint/number.h
index c116ddf..4fa6ed7 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -431,13 +431,14 @@
return AInt(result);
}
-/// @returns a + b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedAdd(AFloat a, AFloat b) {
- auto result = a.value + b.value;
- if (!std::isfinite(result)) {
+/// @returns a + b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedAdd(FloatingPointT a, FloatingPointT b) {
+ auto result = FloatingPointT{a.value + b.value};
+ if (!std::isfinite(result.value)) {
return {};
}
- return AFloat{result};
+ return result;
}
/// @returns a - b, or an empty optional if the resulting value overflowed the AInt
@@ -462,13 +463,14 @@
return AInt(result);
}
-/// @returns a + b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedSub(AFloat a, AFloat b) {
- auto result = a.value - b.value;
- if (!std::isfinite(result)) {
+/// @returns a + b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedSub(FloatingPointT a, FloatingPointT b) {
+ auto result = FloatingPointT{a.value - b.value};
+ if (!std::isfinite(result.value)) {
return {};
}
- return AFloat{result};
+ return result;
}
/// @returns a * b, or an empty optional if the resulting value overflowed the AInt
@@ -505,13 +507,14 @@
return AInt(result);
}
-/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
- auto result = a.value * b.value;
- if (!std::isfinite(result)) {
+/// @returns a * b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedMul(FloatingPointT a, FloatingPointT b) {
+ auto result = FloatingPointT{a.value * b.value};
+ if (!std::isfinite(result.value)) {
return {};
}
- return AFloat{result};
+ return result;
}
/// @returns a / b, or an empty optional if the resulting value overflowed the AInt
@@ -527,13 +530,14 @@
return AInt{a.value / b.value};
}
-/// @returns a / b, or an empty optional if the resulting value overflowed the AFloat
-inline std::optional<AFloat> CheckedDiv(AFloat a, AFloat b) {
- auto result = a.value / b.value;
- if (!std::isfinite(result)) {
+/// @returns a / b, or an empty optional if the resulting value overflowed the float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedDiv(FloatingPointT a, FloatingPointT b) {
+ auto result = FloatingPointT{a.value / b.value};
+ if (!std::isfinite(result.value)) {
return {};
}
- return AFloat{result};
+ return result;
}
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index e6ebfe6..f21e531 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -14,6 +14,7 @@
#include <cmath>
#include <tuple>
+#include <variant>
#include <vector>
#include "src/tint/program_builder.h"
@@ -26,6 +27,15 @@
namespace tint {
namespace {
+// Concats any number of std::vectors
+template <typename Vec, typename... Vecs>
+[[nodiscard]] inline auto Concat(Vec&& v1, Vecs&&... vs) {
+ auto total_size = v1.size() + (vs.size() + ...);
+ v1.reserve(total_size);
+ (std::move(vs.begin(), vs.end(), std::back_inserter(v1)), ...);
+ return std::move(v1);
+}
+
// Next ULP up from kHighestF32 for a float64.
constexpr double kHighestF32NextULP = 0x1.fffffe0000001p+127;
@@ -378,16 +388,23 @@
#define OVERFLOW \
{}
-using BinaryCheckedCase_AInt = std::tuple<std::optional<AInt>, AInt, AInt>;
-using BinaryCheckedCase_AFloat = std::tuple<std::optional<AFloat>, AFloat, AFloat>;
+template <typename T>
+auto Overflow = std::optional<T>{};
+using BinaryCheckedCase_AInt = std::tuple<std::optional<AInt>, AInt, AInt>;
using CheckedAddTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
+
+using FloatInputTypes = std::variant<AFloat, f32, f16>;
+using FloatExpectedTypes =
+ std::variant<std::optional<AFloat>, std::optional<f32>, std::optional<f16>>;
+using BinaryCheckedCase_Float = std::tuple<FloatExpectedTypes, FloatInputTypes, FloatInputTypes>;
+
TEST_P(CheckedAddTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " + 0x" << b;
- EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << a << " + 0x" << b;
+ EXPECT_TRUE(CheckedAdd(a, b) == expect) << std::hex << "0x" << a << " + 0x" << b;
+ EXPECT_TRUE(CheckedAdd(b, a) == expect) << std::hex << "0x" << a << " + 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedAddTest_AInt,
@@ -417,41 +434,50 @@
////////////////////////////////////////////////////////////////////////
}));
-using CheckedAddTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
-TEST_P(CheckedAddTest_AFloat, Test) {
- auto expect = std::get<0>(GetParam());
- auto a = std::get<1>(GetParam());
- auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " + 0x" << b;
- EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << a << " + 0x" << b;
+using CheckedAddTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedAddTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ EXPECT_TRUE(CheckedAdd(lhs, rhs) == expect)
+ << std::hex << "0x" << lhs << " + 0x" << rhs;
+ EXPECT_TRUE(CheckedAdd(rhs, lhs) == expect)
+ << std::hex << "0x" << lhs << " + 0x" << rhs;
+ },
+ std::get<1>(p));
}
-INSTANTIATE_TEST_SUITE_P(
- CheckedAddTest_AFloat,
- CheckedAddTest_AFloat,
- testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
- {AFloat(0), AFloat(0), AFloat(0)},
- {AFloat(1), AFloat(1), AFloat(0)},
- {AFloat(2), AFloat(1), AFloat(1)},
- {AFloat(0), AFloat(-1), AFloat(1)},
- {AFloat(3), AFloat(2), AFloat(1)},
- {AFloat(-1), AFloat(-2), AFloat(1)},
- {AFloat(0x300), AFloat(0x100), AFloat(0x200)},
- {AFloat(0x100), AFloat(-0x100), AFloat(0x200)},
- {AFloat::Highest(), AFloat(1), AFloat(AFloat::kHighestValue - 1)},
- {AFloat::Lowest(), AFloat(-1), AFloat(AFloat::kLowestValue + 1)},
- {AFloat::Highest(), AFloat::Highest(), AFloat(0)},
- {AFloat::Lowest(), AFloat::Lowest(), AFloat(0)},
- {OVERFLOW, AFloat::Highest(), AFloat::Highest()},
- {OVERFLOW, AFloat::Lowest(), AFloat::Lowest()},
- ////////////////////////////////////////////////////////////////////////
- }));
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedAddTest_FloatCases() {
+ return {
+ {T(0), T(0), T(0)},
+ {T(1), T(1), T(0)},
+ {T(2), T(1), T(1)},
+ {T(0), T(-1), T(1)},
+ {T(3), T(2), T(1)},
+ {T(-1), T(-2), T(1)},
+ {T(0x300), T(0x100), T(0x200)},
+ {T(0x100), T(-0x100), T(0x200)},
+ {T::Highest(), T::Highest(), T(0)},
+ {T::Lowest(), T::Lowest(), T(0)},
+ {Overflow<T>, T::Highest(), T::Highest()},
+ {Overflow<T>, T::Lowest(), T::Lowest()},
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedAddTest_Float,
+ CheckedAddTest_Float,
+ testing::ValuesIn(Concat(CheckedAddTest_FloatCases<AFloat>(),
+ CheckedAddTest_FloatCases<f32>(),
+ CheckedAddTest_FloatCases<f16>())));
using CheckedSubTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
TEST_P(CheckedSubTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedSub(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+ EXPECT_TRUE(CheckedSub(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedSubTest_AInt,
@@ -480,40 +506,48 @@
////////////////////////////////////////////////////////////////////////
}));
-using CheckedSubTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
-TEST_P(CheckedSubTest_AFloat, Test) {
- auto expect = std::get<0>(GetParam());
- auto a = std::get<1>(GetParam());
- auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedSub(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+using CheckedSubTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedSubTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ EXPECT_TRUE(CheckedSub(lhs, rhs) == expect)
+ << std::hex << "0x" << lhs << " - 0x" << rhs;
+ },
+ std::get<1>(p));
}
-INSTANTIATE_TEST_SUITE_P(
- CheckedSubTest_AFloat,
- CheckedSubTest_AFloat,
- testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
- {AFloat(0), AFloat(0), AFloat(0)},
- {AFloat(1), AFloat(1), AFloat(0)},
- {AFloat(0), AFloat(1), AFloat(1)},
- {AFloat(-2), AFloat(-1), AFloat(1)},
- {AFloat(1), AFloat(2), AFloat(1)},
- {AFloat(-3), AFloat(-2), AFloat(1)},
- {AFloat(0x100), AFloat(0x300), AFloat(0x200)},
- {AFloat(-0x300), AFloat(-0x100), AFloat(0x200)},
- {AFloat::Highest(), AFloat(AFloat::kHighestValue - 1), AFloat(-1)},
- {AFloat::Lowest(), AFloat(AFloat::kLowestValue + 1), AFloat(1)},
- {AFloat::Highest(), AFloat::Highest(), AFloat(0)},
- {AFloat::Lowest(), AFloat::Lowest(), AFloat(0)},
- {OVERFLOW, AFloat::Lowest(), AFloat::Highest()},
- ////////////////////////////////////////////////////////////////////////
- }));
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedSubTest_FloatCases() {
+ return {
+ {T(0), T(0), T(0)},
+ {T(1), T(1), T(0)},
+ {T(0), T(1), T(1)},
+ {T(-2), T(-1), T(1)},
+ {T(1), T(2), T(1)},
+ {T(-3), T(-2), T(1)},
+ {T(0x100), T(0x300), T(0x200)},
+ {T(-0x300), T(-0x100), T(0x200)},
+ {T::Highest(), T::Highest(), T(0)},
+ {T::Lowest(), T::Lowest(), T(0)},
+ {Overflow<T>, T::Lowest(), T::Highest()},
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedSubTest_Float,
+ CheckedSubTest_Float,
+ testing::ValuesIn(Concat(CheckedSubTest_FloatCases<AFloat>(),
+ CheckedSubTest_FloatCases<f32>(),
+ CheckedSubTest_FloatCases<f16>())));
using CheckedMulTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
TEST_P(CheckedMulTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedMul(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
- EXPECT_EQ(CheckedMul(b, a), expect) << std::hex << "0x" << a << " * 0x" << b;
+ EXPECT_TRUE(CheckedMul(a, b) == expect) << std::hex << "0x" << a << " * 0x" << b;
+ EXPECT_TRUE(CheckedMul(b, a) == expect) << std::hex << "0x" << a << " * 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedMulTest_AInt,
@@ -553,12 +587,48 @@
////////////////////////////////////////////////////////////////////////
}));
+using CheckedMulTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedMulTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ EXPECT_TRUE(CheckedMul(lhs, rhs) == expect)
+ << std::hex << "0x" << lhs << " * 0x" << rhs;
+ EXPECT_TRUE(CheckedMul(rhs, lhs) == expect)
+ << std::hex << "0x" << lhs << " * 0x" << rhs;
+ },
+ std::get<1>(p));
+}
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedMulTest_FloatCases() {
+ return {
+ {T(0), T(0), T(0)},
+ {T(0), T(1), T(0)},
+ {T(1), T(1), T(1)},
+ {T(-1), T(-1), T(1)},
+ {T(2), T(2), T(1)},
+ {T(-2), T(-2), T(1)},
+ {T(0), T::Highest(), T(0)},
+ {T(0), T::Lowest(), -T(0)},
+ {Overflow<T>, T::Highest(), T::Highest()},
+ {Overflow<T>, T::Lowest(), T::Lowest()},
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedMulTest_Float,
+ CheckedMulTest_Float,
+ testing::ValuesIn(Concat(CheckedMulTest_FloatCases<AFloat>(),
+ CheckedMulTest_FloatCases<f32>(),
+ CheckedMulTest_FloatCases<f16>())));
+
using CheckedDivTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
TEST_P(CheckedDivTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedDiv(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+ EXPECT_TRUE(CheckedDiv(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedDivTest_AInt,
@@ -579,33 +649,43 @@
////////////////////////////////////////////////////////////////////////
}));
-using CheckedDivTest_AFloat = testing::TestWithParam<BinaryCheckedCase_AFloat>;
-TEST_P(CheckedDivTest_AFloat, Test) {
- auto expect = std::get<0>(GetParam());
- auto a = std::get<1>(GetParam());
- auto b = std::get<2>(GetParam());
- EXPECT_EQ(CheckedDiv(a, b), expect) << std::hex << "0x" << a << " - 0x" << b;
+using CheckedDivTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedDivTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ EXPECT_TRUE(CheckedDiv(lhs, rhs) == expect)
+ << std::hex << "0x" << lhs << " / 0x" << rhs;
+ },
+ std::get<1>(p));
}
-INSTANTIATE_TEST_SUITE_P(
- CheckedDivTest_AFloat,
- CheckedDivTest_AFloat,
- testing::ValuesIn(std::vector<BinaryCheckedCase_AFloat>{
- {AFloat(0), AFloat(0), AFloat(1)},
- {AFloat(1), AFloat(1), AFloat(1)},
- {AFloat(1), AFloat(1), AFloat(1)},
- {AFloat(2), AFloat(2), AFloat(1)},
- {AFloat(2), AFloat(4), AFloat(2)},
- {AFloat::Highest(), AFloat::Highest(), AFloat(1)},
- {AFloat::Lowest(), AFloat::Lowest(), AFloat(1)},
- {AFloat(1), AFloat::Highest(), AFloat::Highest()},
- {AFloat(0), AFloat(0), AFloat::Highest()},
- {-AFloat(0), AFloat(0), AFloat::Lowest()},
- {OVERFLOW, AFloat(123), AFloat(0)},
- {OVERFLOW, AFloat(123), AFloat(-0)},
- {OVERFLOW, AFloat(-123), AFloat(0)},
- {OVERFLOW, AFloat(-123), AFloat(-0)},
- ////////////////////////////////////////////////////////////////////////
- }));
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedDivTest_FloatCases() {
+ return {
+ {T(0), T(0), T(1)},
+ {T(1), T(1), T(1)},
+ {T(1), T(1), T(1)},
+ {T(2), T(2), T(1)},
+ {T(2), T(4), T(2)},
+ {T::Highest(), T::Highest(), T(1)},
+ {T::Lowest(), T::Lowest(), T(1)},
+ {T(1), T::Highest(), T::Highest()},
+ {T(0), T(0), T::Highest()},
+ {-T(0), T(0), T::Lowest()},
+ {Overflow<T>, T(123), T(0)},
+ {Overflow<T>, T(123), T(-0)},
+ {Overflow<T>, T(-123), T(0)},
+ {Overflow<T>, T(-123), T(-0)},
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedDivTest_Float,
+ CheckedDivTest_Float,
+ testing::ValuesIn(Concat(CheckedDivTest_FloatCases<AFloat>(),
+ CheckedDivTest_FloatCases<f32>(),
+ CheckedDivTest_FloatCases<f16>())));
using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;