tint: Add Checked[Add|Mul|Madd]()

Test-for-overflow utilities for AInt.

Bug: tint:1504
Change-Id: I974ef829c72aaa4c2012550855227f71d4a370a0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91700
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
diff --git a/src/tint/number.h b/src/tint/number.h
index 6efb023..b4c5ca4 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -19,7 +19,10 @@
 #include <functional>
 #include <limits>
 #include <ostream>
+// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
+#include <optional>  // NOLINT(build/include_order)
 
+#include "src/tint/utils/compiler_macros.h"
 #include "src/tint/utils/result.h"
 
 // Forward declaration
@@ -184,33 +187,6 @@
     return !(a == b);
 }
 
-/// Enumerator of failure reasons when converting from one number to another.
-enum class ConversionFailure {
-    kExceedsPositiveLimit,  // The value was too big (+'ve) to fit in the target type
-    kExceedsNegativeLimit,  // The value was too big (-'ve) to fit in the target type
-};
-
-/// Writes the conversion failure message to the ostream.
-/// @param out the std::ostream to write to
-/// @param failure the ConversionFailure
-/// @return the std::ostream so calls can be chained
-std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
-
-/// Converts a number from one type to another, checking that the value fits in the target type.
-/// @returns the resulting value of the conversion, or a failure reason.
-template <typename TO, typename FROM>
-utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
-    using T = decltype(UnwrapNumber<TO>() + num.value);
-    const auto value = static_cast<T>(num.value);
-    if (value > static_cast<T>(TO::kHighest)) {
-        return ConversionFailure::kExceedsPositiveLimit;
-    }
-    if (value < static_cast<T>(TO::kLowest)) {
-        return ConversionFailure::kExceedsNegativeLimit;
-    }
-    return TO(value);  // Success
-}
-
 /// The partial specification of Number for f16 type, storing the f16 value as float,
 /// and enforcing proper explicit casting.
 template <>
@@ -282,6 +258,114 @@
 /// However since C++ don't have native binary16 type, the value is stored as float.
 using f16 = Number<detail::NumberKindF16>;
 
+/// Reasons why converting a number from one type to another can fail (see CheckedConvert()).
+enum class ConversionFailure {
+    kExceedsPositiveLimit,  // The value was greater than the target type's highest value
+    kExceedsNegativeLimit,  // The value was less than the target type's lowest value
+};
+
+/// Writes a message describing the ConversionFailure to the ostream.
+/// @param out the std::ostream to write to
+/// @param failure the ConversionFailure to describe
+/// @return the std::ostream so calls can be chained
+std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
+
+/// Converts a number from one type to another, checking that the value fits in the target type.
+/// @returns the converted value, or the ConversionFailure describing the limit that was exceeded.
+template <typename TO, typename FROM>
+utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
+    using T = decltype(UnwrapNumber<TO>() + num.value);  // common type of the two value types
+    const auto value = static_cast<T>(num.value);
+    if (value > static_cast<T>(TO::kHighest)) {  // NOTE(review): may round up if T is float
+        return ConversionFailure::kExceedsPositiveLimit;
+    }
+    if (value < static_cast<T>(TO::kLowest)) {  // NOTE(review): wraps if T unsigned, TO signed
+        return ConversionFailure::kExceedsNegativeLimit;
+    }
+    return TO(value);  // Success
+}
+
+/// Defines TINT_HAS_OVERFLOW_BUILTINS if the compiler provides __builtin_add_overflow() and
+/// __builtin_mul_overflow(). Otherwise CheckedAdd() and CheckedMul() emulate these builtins with
+/// the algorithms described in:
+/// https://wiki.sei.cmu.edu/confluence/display/c/INT32-C.+Ensure+that+operations+on+signed+integers+do+not+result+in+overflow
+#if defined(__GNUC__) && __GNUC__ >= 5
+#define TINT_HAS_OVERFLOW_BUILTINS
+#elif defined(__clang__)  // clang may report an older __GNUC__, so query the builtins directly
+#if __has_builtin(__builtin_add_overflow) && __has_builtin(__builtin_mul_overflow)
+#define TINT_HAS_OVERFLOW_BUILTINS
+#endif  // __has_builtin(...)
+#endif  // defined(__GNUC__) && __GNUC__ >= 5
+
+/// @returns the sum a + b, or std::nullopt if the sum cannot be represented by an AInt
+inline std::optional<AInt> CheckedAdd(AInt a, AInt b) {
+    int64_t sum = 0;
+#ifdef TINT_HAS_OVERFLOW_BUILTINS
+    // The builtin returns true if the exact sum does not fit in 'sum'.
+    const bool overflowed = __builtin_add_overflow(a.value, b.value, &sum);
+#else   // TINT_HAS_OVERFLOW_BUILTINS
+    // Compare against the limits *before* adding, as signed integer overflow is undefined.
+    // A non-negative 'a' can only push the sum above kHighest; a negative 'a' only below kLowest.
+    const bool overflowed = (a.value >= 0)
+                                ? (b.value > AInt::kHighest - a.value)
+                                : (b.value < AInt::kLowest - a.value);
+    if (!overflowed) {
+        sum = a.value + b.value;
+    }
+#endif  // TINT_HAS_OVERFLOW_BUILTINS
+    if (overflowed) {
+        return std::nullopt;
+    }
+    return AInt(sum);
+}
+
+/// @returns the product a * b, or std::nullopt if the product cannot be represented by an AInt
+inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
+    int64_t product = 0;
+#ifdef TINT_HAS_OVERFLOW_BUILTINS
+    // The builtin returns true if the exact product does not fit in 'product'.
+    const bool overflowed = __builtin_mul_overflow(a.value, b.value, &product);
+#else   // TINT_HAS_OVERFLOW_BUILTINS
+    // Check limits before multiplying (signed overflow is UB); one branch per sign combination.
+    const int64_t lhs = a.value;
+    const int64_t rhs = b.value;
+    bool overflowed = false;
+    if (lhs > 0 && rhs > 0) {
+        // (+) * (+): can only exceed kHighest.
+        overflowed = lhs > AInt::kHighest / rhs;
+    } else if (lhs > 0) {
+        // (+) * (-/0): can only fall below kLowest.
+        overflowed = rhs < AInt::kLowest / lhs;
+    } else if (rhs > 0) {
+        // (-/0) * (+): can only fall below kLowest.
+        overflowed = lhs < AInt::kLowest / rhs;
+    } else if (lhs != 0) {
+        // (-) * (-/0): can only exceed kHighest.
+        overflowed = rhs < AInt::kHighest / lhs;
+    }
+    if (!overflowed) {
+        product = lhs * rhs;
+    }
+#endif  // TINT_HAS_OVERFLOW_BUILTINS
+    if (overflowed) {
+        return std::nullopt;
+    }
+    return AInt(product);
+}
+
+/// @returns a * b + c, or std::nullopt if the product or the sum overflows the AInt
+inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
+    // Work around GCC's bogus -Wmaybe-uninitialized warning, see:
+    // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
+    TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+    const std::optional<AInt> product = CheckedMul(a, b);
+    if (!product.has_value()) {
+        return std::nullopt;  // a * b overflowed
+    }
+    return CheckedAdd(*product, c);  // the addition may still overflow
+    TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
+}
+
 }  // namespace tint
 
 namespace tint::number_suffixes {
diff --git a/src/tint/number_test.cc b/src/tint/number_test.cc
index a7bb2fe..34b4d39 100644
--- a/src/tint/number_test.cc
+++ b/src/tint/number_test.cc
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include <cmath>
+#include <tuple>
+#include <vector>
 
 #include "src/tint/program_builder.h"
 #include "src/tint/utils/compiler_macros.h"
@@ -141,6 +143,165 @@
     EXPECT_TRUE(std::isnan(f16(nan)));
 }
 
+using BinaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt>;  // {expect, a, b}
+
+#undef OVERFLOW  // Already defined by corecrt_math.h, so undefine it before reusing the name.
+#define OVERFLOW \
+    std::nullopt  // An empty optional: the operation is expected to overflow.
+
+using CheckedAddTest = testing::TestWithParam<BinaryCheckedCase>;
+TEST_P(CheckedAddTest, Test) {
+    auto expect = std::get<0>(GetParam());
+    auto a = std::get<1>(GetParam());
+    auto b = std::get<2>(GetParam());
+    EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " + 0x" << b;
+    EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << b << " + 0x" << a;
+}
+INSTANTIATE_TEST_SUITE_P(
+    CheckedAddTest,
+    CheckedAddTest,
+    testing::ValuesIn(std::vector<BinaryCheckedCase>{
+        {AInt(0), AInt(0), AInt(0)},
+        {AInt(1), AInt(1), AInt(0)},
+        {AInt(2), AInt(1), AInt(1)},
+        {AInt(0), AInt(-1), AInt(1)},
+        {AInt(3), AInt(2), AInt(1)},
+        {AInt(-1), AInt(-2), AInt(1)},
+        {AInt(0x300), AInt(0x100), AInt(0x200)},
+        {AInt(0x100), AInt(-0x100), AInt(0x200)},
+        {AInt(AInt::kHighest), AInt(1), AInt(AInt::kHighest - 1)},
+        {AInt(AInt::kLowest), AInt(-1), AInt(AInt::kLowest + 1)},
+        {AInt(AInt::kHighest), AInt(0x7fffffff00000000ll), AInt(0x00000000ffffffffll)},
+        {AInt(AInt::kHighest), AInt(AInt::kHighest), AInt(0)},
+        {AInt(AInt::kLowest), AInt(AInt::kLowest), AInt(0)},
+        {AInt(-1), AInt(AInt::kHighest), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(1), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(-1), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(2), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(-2), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(10000), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(-10000), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(AInt::kLowest), AInt(AInt::kLowest)},
+    }));
+
+using CheckedMulTest = testing::TestWithParam<BinaryCheckedCase>;
+TEST_P(CheckedMulTest, Test) {
+    auto expect = std::get<0>(GetParam());
+    auto a = std::get<1>(GetParam());
+    auto b = std::get<2>(GetParam());
+    EXPECT_EQ(CheckedMul(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
+    EXPECT_EQ(CheckedMul(b, a), expect) << std::hex << "0x" << b << " * 0x" << a;
+}
+INSTANTIATE_TEST_SUITE_P(
+    CheckedMulTest,
+    CheckedMulTest,
+    testing::ValuesIn(std::vector<BinaryCheckedCase>{
+        {AInt(0), AInt(0), AInt(0)},
+        {AInt(0), AInt(1), AInt(0)},
+        {AInt(1), AInt(1), AInt(1)},
+        {AInt(-1), AInt(-1), AInt(1)},
+        {AInt(2), AInt(2), AInt(1)},
+        {AInt(-2), AInt(-2), AInt(1)},
+        {AInt(0x20000), AInt(0x100), AInt(0x200)},
+        {AInt(-0x20000), AInt(-0x100), AInt(0x200)},
+        {AInt(0x4000000000000000ll), AInt(0x80000000ll), AInt(0x80000000ll)},
+        {AInt(0x4000000000000000ll), AInt(-0x80000000ll), AInt(-0x80000000ll)},
+        {AInt(0x1000000000000000ll), AInt(0x40000000ll), AInt(0x40000000ll)},
+        {AInt(-0x1000000000000000ll), AInt(-0x40000000ll), AInt(0x40000000ll)},
+        {AInt(0x100000000000000ll), AInt(0x1000000), AInt(0x100000000ll)},
+        {AInt(0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(2)},
+        {AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2)},
+        {AInt(-0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(2)},
+        {AInt(0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(-2)},
+        {AInt(0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(4)},
+        {AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4)},
+        {AInt(-0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(4)},
+        {AInt(0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(-4)},
+        {AInt(-0x8000000000000000ll), AInt(0x1000000000000000ll), AInt(-8)},
+        {AInt(-0x8000000000000000ll), AInt(-0x1000000000000000ll), AInt(8)},
+        {AInt(0), AInt(AInt::kHighest), AInt(0)},
+        {AInt(0), AInt(AInt::kLowest), AInt(0)},
+        {OVERFLOW, AInt(0x1000000000000000ll), AInt(8)},
+        {OVERFLOW, AInt(-0x1000000000000000ll), AInt(-8)},
+        {OVERFLOW, AInt(0x800000000000000ll), AInt(0x10)},
+        {OVERFLOW, AInt(0x80000000ll), AInt(0x100000000ll)},
+        {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(AInt::kLowest), AInt(-1)},
+    }));
+
+using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;
+
+using CheckedMaddTest = testing::TestWithParam<TernaryCheckedCase>;
+TEST_P(CheckedMaddTest, Test) {
+    auto expect = std::get<0>(GetParam());
+    auto a = std::get<1>(GetParam());
+    auto b = std::get<2>(GetParam());
+    auto c = std::get<3>(GetParam());
+    EXPECT_EQ(CheckedMadd(a, b, c), expect)
+        << std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
+    EXPECT_EQ(CheckedMadd(b, a, c), expect)
+        << std::hex << "0x" << b << " * 0x" << a << " + 0x" << c;
+}
+INSTANTIATE_TEST_SUITE_P(
+    CheckedMaddTest,
+    CheckedMaddTest,
+    testing::ValuesIn(std::vector<TernaryCheckedCase>{
+        {AInt(0), AInt(0), AInt(0), AInt(0)},
+        {AInt(0), AInt(1), AInt(0), AInt(0)},
+        {AInt(1), AInt(1), AInt(1), AInt(0)},
+        {AInt(2), AInt(1), AInt(1), AInt(1)},
+        {AInt(0), AInt(1), AInt(-1), AInt(1)},
+        {AInt(-1), AInt(1), AInt(-2), AInt(1)},
+        {AInt(-1), AInt(-1), AInt(1), AInt(0)},
+        {AInt(2), AInt(2), AInt(1), AInt(0)},
+        {AInt(-2), AInt(-2), AInt(1), AInt(0)},
+        {AInt(0), AInt(AInt::kHighest), AInt(0), AInt(0)},
+        {AInt(0), AInt(AInt::kLowest), AInt(0), AInt(0)},
+        {AInt(3), AInt(1), AInt(2), AInt(1)},
+        {AInt(0x300), AInt(1), AInt(0x100), AInt(0x200)},
+        {AInt(0x100), AInt(1), AInt(-0x100), AInt(0x200)},
+        {AInt(0x20000), AInt(0x100), AInt(0x200), AInt(0)},
+        {AInt(-0x20000), AInt(-0x100), AInt(0x200), AInt(0)},
+        {AInt(0x4000000000000000ll), AInt(0x80000000ll), AInt(0x80000000ll), AInt(0)},
+        {AInt(0x4000000000000000ll), AInt(-0x80000000ll), AInt(-0x80000000ll), AInt(0)},
+        {AInt(0x1000000000000000ll), AInt(0x40000000ll), AInt(0x40000000ll), AInt(0)},
+        {AInt(-0x1000000000000000ll), AInt(-0x40000000ll), AInt(0x40000000ll), AInt(0)},
+        {AInt(0x100000000000000ll), AInt(0x1000000), AInt(0x100000000ll), AInt(0)},
+        {AInt(0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(2), AInt(0)},
+        {AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2), AInt(0)},
+        {AInt(-0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(2), AInt(0)},
+        {AInt(0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(-2), AInt(0)},
+        {AInt(0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(4), AInt(0)},
+        {AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4), AInt(0)},
+        {AInt(-0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(4), AInt(0)},
+        {AInt(0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(-4), AInt(0)},
+        {AInt(-0x8000000000000000ll), AInt(0x1000000000000000ll), AInt(-8), AInt(0)},
+        {AInt(-0x8000000000000000ll), AInt(-0x1000000000000000ll), AInt(8), AInt(0)},
+        {AInt(AInt::kHighest), AInt(1), AInt(1), AInt(AInt::kHighest - 1)},
+        {AInt(AInt::kLowest), AInt(1), AInt(-1), AInt(AInt::kLowest + 1)},
+        {AInt(AInt::kHighest), AInt(1), AInt(0x7fffffff00000000ll), AInt(0x00000000ffffffffll)},
+        {AInt(AInt::kHighest), AInt(1), AInt(AInt::kHighest), AInt(0)},
+        {AInt(AInt::kLowest), AInt(1), AInt(AInt::kLowest), AInt(0)},
+        {OVERFLOW, AInt(0x1000000000000000ll), AInt(8), AInt(0)},
+        {OVERFLOW, AInt(-0x1000000000000000ll), AInt(-8), AInt(0)},
+        {OVERFLOW, AInt(0x800000000000000ll), AInt(0x10), AInt(0)},
+        {OVERFLOW, AInt(0x80000000ll), AInt(0x100000000ll), AInt(0)},
+        {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest), AInt(0)},
+        {OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kLowest), AInt(0)},
+        {OVERFLOW, AInt(1), AInt(1), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(1), AInt(-1), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(1), AInt(2), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(1), AInt(-2), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(1), AInt(10000), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(1), AInt(-10000), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(1), AInt(AInt::kHighest), AInt(AInt::kHighest)},
+        {OVERFLOW, AInt(1), AInt(AInt::kLowest), AInt(AInt::kLowest)},
+        {OVERFLOW, AInt(1), AInt(AInt::kHighest), AInt(1)},
+        {OVERFLOW, AInt(1), AInt(AInt::kLowest), AInt(-1)},
+    }));
+
 TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
 
 }  // namespace
diff --git a/src/tint/utils/compiler_macros.h b/src/tint/utils/compiler_macros.h
index ada138e..34965c6 100644
--- a/src/tint/utils/compiler_macros.h
+++ b/src/tint/utils/compiler_macros.h
@@ -20,23 +20,63 @@
 #define TINT_REQUIRE_SEMICOLON static_assert(true)
 
 #if defined(_MSC_VER)
-#define TINT_WARNING_UNREACHABLE_CODE 4702
-#define TINT_WARNING_CONSTANT_OVERFLOW 4756
+////////////////////////////////////////////////////////////////////////////////
+// MSVC: warnings are disabled by numeric ID with __pragma(warning(disable : N))
+////////////////////////////////////////////////////////////////////////////////
+#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW __pragma(warning(disable : 4756))
+#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED /* currently no-op on MSVC */
+#define TINT_DISABLE_WARNING_UNREACHABLE_CODE __pragma(warning(disable : 4702))
 
 // clang-format off
-#define TINT_BEGIN_DISABLE_WARNING(name)                        \
-    __pragma(warning(push))                                     \
-    __pragma(warning(disable:TINT_CONCAT(TINT_WARNING_, name))) \
+#define TINT_BEGIN_DISABLE_WARNING(name)     \
+    __pragma(warning(push))                  \
+    TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
     TINT_REQUIRE_SEMICOLON
-#define TINT_END_DISABLE_WARNING(name)                          \
-    __pragma(warning(pop))                                      \
+#define TINT_END_DISABLE_WARNING(name)       \
+    __pragma(warning(pop))                   \
+    TINT_REQUIRE_SEMICOLON
+// clang-format on
+#elif defined(__clang__)
+////////////////////////////////////////////////////////////////////////////////
+// Clang: clang diagnostic push/pop; every TINT_DISABLE_WARNING_* is currently a no-op
+////////////////////////////////////////////////////////////////////////////////
+#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW   /* currently no-op */
+#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED /* currently no-op */
+#define TINT_DISABLE_WARNING_UNREACHABLE_CODE    /* currently no-op */
+
+// clang-format off
+#define TINT_BEGIN_DISABLE_WARNING(name)     \
+    _Pragma("clang diagnostic push")         \
+    TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
+    TINT_REQUIRE_SEMICOLON
+#define TINT_END_DISABLE_WARNING(name)       \
+    _Pragma("clang diagnostic pop")          \
+    TINT_REQUIRE_SEMICOLON
+// clang-format on
+#elif defined(__GNUC__)
+////////////////////////////////////////////////////////////////////////////////
+// GCC: GCC diagnostic push/pop; only MAYBE_UNINITIALIZED maps to a real diagnostic
+////////////////////////////////////////////////////////////////////////////////
+#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW /* currently no-op */
+#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED \
+    _Pragma("GCC diagnostic ignored \"-Wmaybe-uninitialized\"")
+#define TINT_DISABLE_WARNING_UNREACHABLE_CODE /* currently no-op */
+
+// clang-format off
+#define TINT_BEGIN_DISABLE_WARNING(name)     \
+    _Pragma("GCC diagnostic push")           \
+    TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
+    TINT_REQUIRE_SEMICOLON
+#define TINT_END_DISABLE_WARNING(name)       \
+    _Pragma("GCC diagnostic pop")            \
     TINT_REQUIRE_SEMICOLON
 // clang-format on
 #else
-// clang-format off
+////////////////////////////////////////////////////////////////////////////////
+// Other compilers: BEGIN/END only require a trailing semicolon
+////////////////////////////////////////////////////////////////////////////////
 #define TINT_BEGIN_DISABLE_WARNING(name) TINT_REQUIRE_SEMICOLON
 #define TINT_END_DISABLE_WARNING(name) TINT_REQUIRE_SEMICOLON
-// clang-format on
-#endif  // defined(_MSC_VER)
+#endif  // defined(_MSC_VER)
 
 #endif  // SRC_TINT_UTILS_COMPILER_MACROS_H_