tint: add CheckedPow
Also improve test validation so that failed tests emit the two values
being compared.
Bug: tint:1581
Change-Id: Ie6f62cb623cf6f50a85ac3229f0968321e45154f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/113820
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/number.h b/src/tint/number.h
index 975ce79..82a9963 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -412,6 +412,9 @@
#endif
#endif
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
+TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+
/// @returns a + b, or an empty optional if the resulting value overflowed the AInt
inline std::optional<AInt> CheckedAdd(AInt a, AInt b) {
int64_t result;
@@ -582,17 +585,29 @@
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
- // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
- TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
-
if (auto mul = CheckedMul(a, b)) {
return CheckedAdd(mul.value(), c);
}
return {};
-
- TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
}
+/// @returns the value of `base` raised to the power `exp`, or an empty optional if the operation
+/// cannot be performed.
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedPow(FloatingPointT base, FloatingPointT exp) {
+ static_assert(IsNumber<FloatingPointT>);
+ if ((base < 0) || (base == 0 && exp <= 0)) {
+ return {};
+ }
+ auto result = FloatingPointT{std::pow(base.value, exp.value)};
+ if (!std::isfinite(result.value)) {
+ return {};
+ }
+ return result;
+}
+
+TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+
} // namespace tint
namespace tint::number_suffixes {
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index a795840..fe03663 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -399,12 +399,32 @@
std::variant<std::optional<AFloat>, std::optional<f32>, std::optional<f16>>;
using BinaryCheckedCase_Float = std::tuple<FloatExpectedTypes, FloatInputTypes, FloatInputTypes>;
+/// Validates that result is equal to expect. If `float_comp` is true, uses EXPECT_FLOAT_EQ to
+/// compare the values.
+template <typename T>
+void ValidateResult(std::optional<T> result, std::optional<T> expect, bool float_comp = false) {
+ if (!expect) {
+ EXPECT_TRUE(!result) << *result;
+ } else {
+ ASSERT_TRUE(result);
+ if constexpr (IsIntegral<T>) {
+ EXPECT_EQ(*result, *expect);
+ } else {
+ if (float_comp) {
+ EXPECT_FLOAT_EQ(*result, *expect);
+ } else {
+ EXPECT_EQ(*result, *expect);
+ }
+ }
+ }
+}
+
TEST_P(CheckedAddTest_AInt, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedAdd(a, b) == expect) << std::hex << "0x" << a << " + 0x" << b;
- EXPECT_TRUE(CheckedAdd(b, a) == expect) << std::hex << "0x" << a << " + 0x" << b;
+ ValidateResult(CheckedAdd(a, b), expect);
+ ValidateResult(CheckedAdd(b, a), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedAddTest_AInt,
@@ -477,7 +497,7 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedSub(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+ ValidateResult(CheckedSub(a, b), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedSubTest_AInt,
@@ -514,8 +534,7 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedSub(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " - 0x" << rhs;
+ ValidateResult(CheckedSub(lhs, rhs), expect);
},
std::get<1>(p));
}
@@ -546,8 +565,8 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedMul(a, b) == expect) << std::hex << "0x" << a << " * 0x" << b;
- EXPECT_TRUE(CheckedMul(b, a) == expect) << std::hex << "0x" << a << " * 0x" << b;
+ ValidateResult(CheckedMul(a, b), expect);
+ ValidateResult(CheckedMul(b, a), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedMulTest_AInt,
@@ -595,10 +614,8 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedMul(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " * 0x" << rhs;
- EXPECT_TRUE(CheckedMul(rhs, lhs) == expect)
- << std::hex << "0x" << lhs << " * 0x" << rhs;
+ ValidateResult(CheckedMul(lhs, rhs), expect);
+ ValidateResult(CheckedMul(rhs, lhs), expect);
},
std::get<1>(p));
}
@@ -628,7 +645,7 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedDiv(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+ ValidateResult(CheckedDiv(a, b), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedDivTest_AInt,
@@ -657,8 +674,7 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedDiv(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " / 0x" << rhs;
+ ValidateResult(CheckedDiv(lhs, rhs), expect);
},
std::get<1>(p));
}
@@ -692,7 +708,7 @@
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
- EXPECT_TRUE(CheckedMod(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+ EXPECT_TRUE(CheckedMod(a, b) == expect) << std::hex << "0x" << a << " % 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedModTest_AInt,
@@ -726,8 +742,7 @@
using T = std::decay_t<decltype(lhs)>;
auto rhs = std::get<T>(std::get<2>(p));
auto expect = std::get<std::optional<T>>(std::get<0>(p));
- EXPECT_TRUE(CheckedMod(lhs, rhs) == expect)
- << std::hex << "0x" << lhs << " / 0x" << rhs;
+ ValidateResult(CheckedMod(lhs, rhs), expect);
},
std::get<1>(p));
}
@@ -759,6 +774,51 @@
CheckedModTest_FloatCases<f32>(),
CheckedModTest_FloatCases<f16>())));
+using CheckedPowTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedPowTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ ValidateResult(CheckedPow(lhs, rhs), expect, /* float_comp */ true);
+ },
+ std::get<1>(p));
+}
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedPowTest_FloatCases() {
+ return {
+ {T(0), T(0), T(1)}, //
+ {T(0), T(0), T::Highest()}, //
+ {T(1), T(1), T(1)}, //
+ {T(1), T(1), T::Lowest()}, //
+ {T(4), T(2), T(2)}, //
+ {T(8), T(2), T(3)}, //
+ {T(1), T(1), T::Highest()}, //
+ {T(1), T(1), -T(1)}, //
+ {T(0.25), T(2), -T(2)}, //
+ {T(0.125), T(2), -T(3)}, //
+ {T(15.625), T(2.5), T(3)}, //
+ {T(11.313708498), T(2), T(3.5)}, //
+ {T(24.705294220), T(2.5), T(3.5)}, //
+ {T(0.0883883476), T(2), -T(3.5)}, //
+ {Overflow<T>, -T(1), T(1)}, //
+ {Overflow<T>, -T(1), T::Highest()}, //
+ {Overflow<T>, T::Lowest(), T(1)}, //
+ {Overflow<T>, T::Lowest(), T::Highest()}, //
+ {Overflow<T>, T::Lowest(), T::Lowest()}, //
+ {Overflow<T>, T(0), T(0)}, //
+ {Overflow<T>, T(0), -T(1)}, //
+ {Overflow<T>, T(0), T::Lowest()}, //
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedPowTest_Float,
+ CheckedPowTest_Float,
+ testing::ValuesIn(Concat(CheckedPowTest_FloatCases<AFloat>(),
+ CheckedPowTest_FloatCases<f32>(),
+ CheckedPowTest_FloatCases<f16>())));
+
using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;
using CheckedMaddTest_AInt = testing::TestWithParam<TernaryCheckedCase>;
@@ -767,10 +827,8 @@
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
auto c = std::get<3>(GetParam());
- EXPECT_EQ(CheckedMadd(a, b, c), expect)
- << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
- EXPECT_EQ(CheckedMadd(b, a, c), expect)
- << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
+ ValidateResult(CheckedMadd(a, b, c), expect);
+ ValidateResult(CheckedMadd(b, a, c), expect);
}
INSTANTIATE_TEST_SUITE_P(
CheckedMaddTest_AInt,