tint/number: add CheckedMod functions
Will be used to implement const eval of binary modulo.
Bug: tint:1581
Change-Id: Ib3cb422b247d57932d0b7cfc0ea8588206c39671
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112321
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/number.h b/src/tint/number.h
index dc66689..9b7ce43 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -543,6 +543,38 @@
return result;
}
+namespace detail {
+/// @returns the remainder of e1 / e2
+template <typename T>
+inline T Mod(T e1, T e2) {
+ return e1 - e2 * static_cast<T>(std::trunc(e1 / e2));
+}
+} // namespace detail
+
+/// @returns the remainder of a / b, or an empty optional if the resulting value overflowed the AInt
+inline std::optional<AInt> CheckedMod(AInt a, AInt b) {
+ if (b == 0) {
+ return {};
+ }
+
+ if (b == -1 && a == AInt::Lowest()) {
+ return {};
+ }
+
+ return AInt{detail::Mod(a.value, b.value)};
+}
+
+/// @returns the remainder of a / b, or an empty optional if the resulting value overflowed the
+/// float value
+template <typename FloatingPointT, typename = traits::EnableIf<IsFloatingPoint<FloatingPointT>>>
+inline std::optional<FloatingPointT> CheckedMod(FloatingPointT a, FloatingPointT b) {
+ auto result = FloatingPointT{detail::Mod(a.value, b.value)};
+ if (!std::isfinite(result.value)) {
+ return {};
+ }
+ return result;
+}
+
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index f21e531..040fd76 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -687,6 +687,58 @@
CheckedDivTest_FloatCases<f32>(),
CheckedDivTest_FloatCases<f16>())));
+using CheckedModTest_AInt = testing::TestWithParam<BinaryCheckedCase_AInt>;
+TEST_P(CheckedModTest_AInt, Test) {
+ auto expect = std::get<0>(GetParam());
+ auto a = std::get<1>(GetParam());
+ auto b = std::get<2>(GetParam());
+ EXPECT_TRUE(CheckedMod(a, b) == expect) << std::hex << "0x" << a << " - 0x" << b;
+}
+INSTANTIATE_TEST_SUITE_P(
+ CheckedModTest_AInt,
+ CheckedModTest_AInt,
+ testing::ValuesIn(std::vector<BinaryCheckedCase_AInt>{
+ {AInt(0), AInt(0), AInt(1)},
+ {AInt(0), AInt(1), AInt(1)},
+ {AInt(1), AInt(10), AInt(3)},
+ {AInt(2), AInt(10), AInt(4)},
+ {AInt(0), AInt::Highest(), AInt::Highest()},
+ {AInt(0), AInt::Lowest(), AInt::Lowest()},
+ {OVERFLOW, AInt::Highest(), AInt(0)},
+ {OVERFLOW, AInt::Lowest(), AInt(0)},
+ ////////////////////////////////////////////////////////////////////////
+ }));
+
+using CheckedModTest_Float = testing::TestWithParam<BinaryCheckedCase_Float>;
+TEST_P(CheckedModTest_Float, Test) {
+ auto& p = GetParam();
+ std::visit(
+ [&](auto&& lhs) {
+ using T = std::decay_t<decltype(lhs)>;
+ auto rhs = std::get<T>(std::get<2>(p));
+ auto expect = std::get<std::optional<T>>(std::get<0>(p));
+ EXPECT_TRUE(CheckedMod(lhs, rhs) == expect)
+ << std::hex << "0x" << lhs << " / 0x" << rhs;
+ },
+ std::get<1>(p));
+}
+template <typename T>
+std::vector<BinaryCheckedCase_Float> CheckedModTest_FloatCases() {
+ return {
+ {T(0.5), T(10.5), T(1)}, {T(0.5), T(10.5), T(2)},
+ {T(1.5), T(10.5), T(3)}, {T(2.5), T(10.5), T(4)},
+ {T(0.5), T(10.5), T(5)}, {T(0), T::Highest(), T::Highest()},
+ {T(0), T::Lowest(), T::Lowest()}, {Overflow<T>, T(123), T(0)},
+ {Overflow<T>, T(123), T(-0)}, {Overflow<T>, T(-123), T(0)},
+ {Overflow<T>, T(-123), T(-0)},
+ };
+}
+INSTANTIATE_TEST_SUITE_P(CheckedModTest_Float,
+ CheckedModTest_Float,
+ testing::ValuesIn(Concat(CheckedModTest_FloatCases<AFloat>(),
+ CheckedModTest_FloatCases<f32>(),
+ CheckedModTest_FloatCases<f16>())));
+
using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;
using CheckedMaddTest_AInt = testing::TestWithParam<TernaryCheckedCase>;