tint: Add Checked[Add|Mul|Madd]()
Test-for-overflow utilities for AInt.
Bug: tint:1504
Change-Id: I974ef829c72aaa4c2012550855227f71d4a370a0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91700
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/number.h b/src/tint/number.h
index 6efb023..b4c5ca4 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -19,7 +19,10 @@
#include <functional>
#include <limits>
#include <ostream>
+// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
+#include <optional> // NOLINT(build/include_order))
+#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/result.h"
// Forward declaration
@@ -184,33 +187,6 @@
return !(a == b);
}
-/// Enumerator of failure reasons when converting from one number to another.
-enum class ConversionFailure {
- kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
- kExceedsNegativeLimit, // The value was too big (-'ve) to fit in the target type
-};
-
-/// Writes the conversion failure message to the ostream.
-/// @param out the std::ostream to write to
-/// @param failure the ConversionFailure
-/// @return the std::ostream so calls can be chained
-std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
-
-/// Converts a number from one type to another, checking that the value fits in the target type.
-/// @returns the resulting value of the conversion, or a failure reason.
-template <typename TO, typename FROM>
-utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
- using T = decltype(UnwrapNumber<TO>() + num.value);
- const auto value = static_cast<T>(num.value);
- if (value > static_cast<T>(TO::kHighest)) {
- return ConversionFailure::kExceedsPositiveLimit;
- }
- if (value < static_cast<T>(TO::kLowest)) {
- return ConversionFailure::kExceedsNegativeLimit;
- }
- return TO(value); // Success
-}
-
/// The partial specification of Number for f16 type, storing the f16 value as float,
/// and enforcing proper explicit casting.
template <>
@@ -282,6 +258,114 @@
/// However since C++ don't have native binary16 type, the value is stored as float.
using f16 = Number<detail::NumberKindF16>;
+/// Enumerator of failure reasons when converting from one number to another.
+enum class ConversionFailure {
+ kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
+ kExceedsNegativeLimit, // The value was too big (-'ve) to fit in the target type
+};
+
+/// Writes the conversion failure message to the ostream.
+/// @param out the std::ostream to write to
+/// @param failure the ConversionFailure
+/// @return the std::ostream so calls can be chained
+std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
+
+/// Converts a number from one type to another, checking that the value fits in the target type.
+/// @returns the resulting value of the conversion, or a failure reason.
+template <typename TO, typename FROM>
+utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
+ using T = decltype(UnwrapNumber<TO>() + num.value);
+ const auto value = static_cast<T>(num.value);
+ if (value > static_cast<T>(TO::kHighest)) {
+ return ConversionFailure::kExceedsPositiveLimit;
+ }
+ if (value < static_cast<T>(TO::kLowest)) {
+ return ConversionFailure::kExceedsNegativeLimit;
+ }
+ return TO(value); // Success
+}
+
+/// Define 'TINT_HAS_OVERFLOW_BUILTINS' if the compiler provide overflow checking builtins.
+/// If the compiler does not support these builtins, then these are emulated with algorithms
+/// described in:
+/// https://wiki.sei.cmu.edu/confluence/display/c/INT32-C.+Ensure+that+operations+on+signed+integers+do+not+result+in+overflow
+#if defined(__GNUC__) && __GNUC__ >= 5
+#define TINT_HAS_OVERFLOW_BUILTINS
+#elif defined(__clang__)
+#if __has_builtin(__builtin_add_overflow) && __has_builtin(__builtin_mul_overflow)
+#define TINT_HAS_OVERFLOW_BUILTINS
+#endif
+#endif
+
+/// @returns a + b, or an empty optional if the resulting value overflowed the AInt
+inline std::optional<AInt> CheckedAdd(AInt a, AInt b) {
+ int64_t result;
+#ifdef TINT_HAS_OVERFLOW_BUILTINS
+ if (__builtin_add_overflow(a.value, b.value, &result)) {
+ return {};
+ }
+#else // TINT_HAS_OVERFLOW_BUILTINS
+ if (a.value >= 0) {
+ if (AInt::kHighest - a.value < b.value) {
+ return {};
+ }
+ } else {
+ if (b.value < AInt::kLowest - a.value) {
+ return {};
+ }
+ }
+ result = a.value + b.value;
+#endif // TINT_HAS_OVERFLOW_BUILTINS
+ return AInt(result);
+}
+
+/// @returns a * b, or an empty optional if the resulting value overflowed the AInt
+inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
+ int64_t result;
+#ifdef TINT_HAS_OVERFLOW_BUILTINS
+ if (__builtin_mul_overflow(a.value, b.value, &result)) {
+ return {};
+ }
+#else // TINT_HAS_OVERFLOW_BUILTINS
+ if (a > 0) {
+ if (b > 0) {
+ if (a > (AInt::kHighest / b)) {
+ return {};
+ }
+ } else {
+ if (b < (AInt::kLowest / a)) {
+ return {};
+ }
+ }
+ } else {
+ if (b > 0) {
+ if (a < (AInt::kLowest / b)) {
+ return {};
+ }
+ } else {
+ if ((a != 0) && (b < (AInt::kHighest / a))) {
+ return {};
+ }
+ }
+ }
+ result = a.value * b.value;
+#endif // TINT_HAS_OVERFLOW_BUILTINS
+ return AInt(result);
+}
+
+/// @returns a * b + c, or an empty optional if the value overflowed the AInt
+inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
+ // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
+ TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+
+ if (auto mul = CheckedMul(a, b)) {
+ return CheckedAdd(mul.value(), c);
+ }
+ return {};
+
+ TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+}
+
} // namespace tint
namespace tint::number_suffixes {
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index a7bb2fe..34b4d39 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -13,6 +13,8 @@
// limitations under the License.
#include <cmath>
+#include <tuple>
+#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/utils/compiler_macros.h"
@@ -141,6 +143,165 @@
EXPECT_TRUE(std::isnan(f16(nan)));
}
+using BinaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt>;
+
+#undef OVERFLOW // corecrt_math.h :(
+#define OVERFLOW \
+ {}
+
+using CheckedAddTest = testing::TestWithParam<BinaryCheckedCase>;
+TEST_P(CheckedAddTest, Test) {
+ auto expect = std::get<0>(GetParam());
+ auto a = std::get<1>(GetParam());
+ auto b = std::get<2>(GetParam());
+ EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
+ EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << a << " * 0x" << b;
+}
+INSTANTIATE_TEST_SUITE_P(
+ CheckedAddTest,
+ CheckedAddTest,
+ testing::ValuesIn(std::vector<BinaryCheckedCase>{
+ {AInt(0), AInt(0), AInt(0)},
+ {AInt(1), AInt(1), AInt(0)},
+ {AInt(2), AInt(1), AInt(1)},
+ {AInt(0), AInt(-1), AInt(1)},
+ {AInt(3), AInt(2), AInt(1)},
+ {AInt(-1), AInt(-2), AInt(1)},
+ {AInt(0x300), AInt(0x100), AInt(0x200)},
+ {AInt(0x100), AInt(-0x100), AInt(0x200)},
+ {AInt(AInt::kHighest), AInt(1), AInt(AInt::kHighest - 1)},
+ {AInt(AInt::kLowest), AInt(-1), AInt(AInt::kLowest + 1)},
+ {AInt(AInt::kHighest), AInt(0x7fffffff00000000ll), AInt(0x00000000ffffffffll)},
+ {AInt(AInt::kHighest), AInt(AInt::kHighest), AInt(0)},
+ {AInt(AInt::kLowest), AInt(AInt::kLowest), AInt(0)},
+ {OVERFLOW, AInt(1), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(-1), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(2), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(-2), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(10000), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(-10000), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(AInt::kLowest), AInt(AInt::kLowest)},
+ ////////////////////////////////////////////////////////////////////////
+ }));
+
+using CheckedMulTest = testing::TestWithParam<BinaryCheckedCase>;
+TEST_P(CheckedMulTest, Test) {
+ auto expect = std::get<0>(GetParam());
+ auto a = std::get<1>(GetParam());
+ auto b = std::get<2>(GetParam());
+ EXPECT_EQ(CheckedMul(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
+ EXPECT_EQ(CheckedMul(b, a), expect) << std::hex << "0x" << a << " * 0x" << b;
+}
+INSTANTIATE_TEST_SUITE_P(
+ CheckedMulTest,
+ CheckedMulTest,
+ testing::ValuesIn(std::vector<BinaryCheckedCase>{
+ {AInt(0), AInt(0), AInt(0)},
+ {AInt(0), AInt(1), AInt(0)},
+ {AInt(1), AInt(1), AInt(1)},
+ {AInt(-1), AInt(-1), AInt(1)},
+ {AInt(2), AInt(2), AInt(1)},
+ {AInt(-2), AInt(-2), AInt(1)},
+ {AInt(0x20000), AInt(0x100), AInt(0x200)},
+ {AInt(-0x20000), AInt(-0x100), AInt(0x200)},
+ {AInt(0x4000000000000000ll), AInt(0x80000000ll), AInt(0x80000000ll)},
+ {AInt(0x4000000000000000ll), AInt(-0x80000000ll), AInt(-0x80000000ll)},
+ {AInt(0x1000000000000000ll), AInt(0x40000000ll), AInt(0x40000000ll)},
+ {AInt(-0x1000000000000000ll), AInt(-0x40000000ll), AInt(0x40000000ll)},
+ {AInt(0x100000000000000ll), AInt(0x1000000), AInt(0x100000000ll)},
+ {AInt(0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(2)},
+ {AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2)},
+ {AInt(-0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(2)},
+ {AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2)},
+ {AInt(0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(4)},
+ {AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4)},
+ {AInt(-0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(4)},
+ {AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4)},
+ {AInt(-0x8000000000000000ll), AInt(0x1000000000000000ll), AInt(-8)},
+ {AInt(-0x8000000000000000ll), AInt(-0x1000000000000000ll), AInt(8)},
+ {AInt(0), AInt(AInt::kHighest), AInt(0)},
+ {AInt(0), AInt(AInt::kLowest), AInt(0)},
+ {OVERFLOW, AInt(0x1000000000000000ll), AInt(8)},
+ {OVERFLOW, AInt(-0x1000000000000000ll), AInt(-8)},
+ {OVERFLOW, AInt(0x800000000000000ll), AInt(0x10)},
+ {OVERFLOW, AInt(0x80000000ll), AInt(0x100000000ll)},
+ {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kLowest)},
+ ////////////////////////////////////////////////////////////////////////
+ }));
+
+using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;
+
+using CheckedMaddTest = testing::TestWithParam<TernaryCheckedCase>;
+TEST_P(CheckedMaddTest, Test) {
+ auto expect = std::get<0>(GetParam());
+ auto a = std::get<1>(GetParam());
+ auto b = std::get<2>(GetParam());
+ auto c = std::get<3>(GetParam());
+ EXPECT_EQ(CheckedMadd(a, b, c), expect)
+ << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
+ EXPECT_EQ(CheckedMadd(b, a, c), expect)
+ << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
+}
+INSTANTIATE_TEST_SUITE_P(
+ CheckedMaddTest,
+ CheckedMaddTest,
+ testing::ValuesIn(std::vector<TernaryCheckedCase>{
+ {AInt(0), AInt(0), AInt(0), AInt(0)},
+ {AInt(0), AInt(1), AInt(0), AInt(0)},
+ {AInt(1), AInt(1), AInt(1), AInt(0)},
+ {AInt(2), AInt(1), AInt(1), AInt(1)},
+ {AInt(0), AInt(1), AInt(-1), AInt(1)},
+ {AInt(-1), AInt(1), AInt(-2), AInt(1)},
+ {AInt(-1), AInt(-1), AInt(1), AInt(0)},
+ {AInt(2), AInt(2), AInt(1), AInt(0)},
+ {AInt(-2), AInt(-2), AInt(1), AInt(0)},
+ {AInt(0), AInt(AInt::kHighest), AInt(0), AInt(0)},
+ {AInt(0), AInt(AInt::kLowest), AInt(0), AInt(0)},
+ {AInt(3), AInt(1), AInt(2), AInt(1)},
+ {AInt(0x300), AInt(1), AInt(0x100), AInt(0x200)},
+ {AInt(0x100), AInt(1), AInt(-0x100), AInt(0x200)},
+ {AInt(0x20000), AInt(0x100), AInt(0x200), AInt(0)},
+ {AInt(-0x20000), AInt(-0x100), AInt(0x200), AInt(0)},
+ {AInt(0x4000000000000000ll), AInt(0x80000000ll), AInt(0x80000000ll), AInt(0)},
+ {AInt(0x4000000000000000ll), AInt(-0x80000000ll), AInt(-0x80000000ll), AInt(0)},
+ {AInt(0x1000000000000000ll), AInt(0x40000000ll), AInt(0x40000000ll), AInt(0)},
+ {AInt(-0x1000000000000000ll), AInt(-0x40000000ll), AInt(0x40000000ll), AInt(0)},
+ {AInt(0x100000000000000ll), AInt(0x1000000), AInt(0x100000000ll), AInt(0)},
+ {AInt(0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(2), AInt(0)},
+ {AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2), AInt(0)},
+ {AInt(-0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(2), AInt(0)},
+ {AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2), AInt(0)},
+ {AInt(0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(4), AInt(0)},
+ {AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4), AInt(0)},
+ {AInt(-0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(4), AInt(0)},
+ {AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4), AInt(0)},
+ {AInt(-0x8000000000000000ll), AInt(0x1000000000000000ll), AInt(-8), AInt(0)},
+ {AInt(-0x8000000000000000ll), AInt(-0x1000000000000000ll), AInt(8), AInt(0)},
+ {AInt(AInt::kHighest), AInt(1), AInt(1), AInt(AInt::kHighest - 1)},
+ {AInt(AInt::kLowest), AInt(1), AInt(-1), AInt(AInt::kLowest + 1)},
+ {AInt(AInt::kHighest), AInt(1), AInt(0x7fffffff00000000ll), AInt(0x00000000ffffffffll)},
+ {AInt(AInt::kHighest), AInt(1), AInt(AInt::kHighest), AInt(0)},
+ {AInt(AInt::kLowest), AInt(1), AInt(AInt::kLowest), AInt(0)},
+ {OVERFLOW, AInt(0x1000000000000000ll), AInt(8), AInt(0)},
+ {OVERFLOW, AInt(-0x1000000000000000ll), AInt(-8), AInt(0)},
+ {OVERFLOW, AInt(0x800000000000000ll), AInt(0x10), AInt(0)},
+ {OVERFLOW, AInt(0x80000000ll), AInt(0x100000000ll), AInt(0)},
+ {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest), AInt(0)},
+ {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kLowest), AInt(0)},
+ {OVERFLOW, AInt(1), AInt(1), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(1), AInt(-1), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(1), AInt(2), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(1), AInt(-2), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(1), AInt(10000), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(1), AInt(-10000), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(1), AInt(AInt::kHighest), AInt(AInt::kHighest)},
+ {OVERFLOW, AInt(1), AInt(AInt::kLowest), AInt(AInt::kLowest)},
+ {OVERFLOW, AInt(1), AInt(AInt::kHighest), AInt(1)},
+ {OVERFLOW, AInt(1), AInt(AInt::kLowest), AInt(-1)},
+ }));
+
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
} // namespace
diff --git a/src/tint/utils/compiler_macros.h b/src/tint/utils/compiler_macros.h
index ada138e..34965c6 100644
--- a/src/tint/utils/compiler_macros.h
+++ b/src/tint/utils/compiler_macros.h
@@ -20,23 +20,63 @@
#define TINT_REQUIRE_SEMICOLON static_assert(true)
#if defined(_MSC_VER)
-#define TINT_WARNING_UNREACHABLE_CODE 4702
-#define TINT_WARNING_CONSTANT_OVERFLOW 4756
+////////////////////////////////////////////////////////////////////////////////
+// MSVC
+////////////////////////////////////////////////////////////////////////////////
+#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW __pragma(warning(disable : 4756))
+#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED /* currently no-op */
+#define TINT_DISABLE_WARNING_UNREACHABLE_CODE __pragma(warning(disable : 4702))
// clang-format off
-#define TINT_BEGIN_DISABLE_WARNING(name) \
- __pragma(warning(push)) \
- __pragma(warning(disable:TINT_CONCAT(TINT_WARNING_, name))) \
+#define TINT_BEGIN_DISABLE_WARNING(name) \
+ __pragma(warning(push)) \
+ TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
TINT_REQUIRE_SEMICOLON
-#define TINT_END_DISABLE_WARNING(name) \
- __pragma(warning(pop)) \
+#define TINT_END_DISABLE_WARNING(name) \
+ __pragma(warning(pop)) \
+ TINT_REQUIRE_SEMICOLON
+// clang-format on
+#elif defined(__clang__)
+////////////////////////////////////////////////////////////////////////////////
+// Clang
+////////////////////////////////////////////////////////////////////////////////
+#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW /* currently no-op */
+#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED /* currently no-op */
+#define TINT_DISABLE_WARNING_UNREACHABLE_CODE /* currently no-op */
+
+// clang-format off
+#define TINT_BEGIN_DISABLE_WARNING(name) \
+ _Pragma("clang diagnostic push") \
+ TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
+ TINT_REQUIRE_SEMICOLON
+#define TINT_END_DISABLE_WARNING(name) \
+ _Pragma("clang diagnostic pop") \
+ TINT_REQUIRE_SEMICOLON
+// clang-format on
+#elif defined(__GNUC__)
+////////////////////////////////////////////////////////////////////////////////
+// GCC
+////////////////////////////////////////////////////////////////////////////////
+#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW /* currently no-op */
+#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED \
+ _Pragma("GCC diagnostic ignored \"-Wmaybe-uninitialized\"")
+#define TINT_DISABLE_WARNING_UNREACHABLE_CODE /* currently no-op */
+
+// clang-format off
+#define TINT_BEGIN_DISABLE_WARNING(name) \
+ _Pragma("GCC diagnostic push") \
+ TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
+ TINT_REQUIRE_SEMICOLON
+#define TINT_END_DISABLE_WARNING(name) \
+ _Pragma("GCC diagnostic pop") \
TINT_REQUIRE_SEMICOLON
// clang-format on
#else
-// clang-format off
+////////////////////////////////////////////////////////////////////////////////
+// Other
+////////////////////////////////////////////////////////////////////////////////
#define TINT_BEGIN_DISABLE_WARNING(name) TINT_REQUIRE_SEMICOLON
#define TINT_END_DISABLE_WARNING(name) TINT_REQUIRE_SEMICOLON
-// clang-format on
-#endif // defined(_MSC_VER)
+#endif
#endif // SRC_TINT_UTILS_COMPILER_MACROS_H_