| // Copyright 2022 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #ifndef SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_ |
| #define SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_ |
| |
| #include <limits> |
| #include <optional> |
| #include <string> |
| #include <utility> |
| |
| #include "gmock/gmock.h" |
| #include "gtest/gtest.h" |
| #include "src/tint/resolver/resolver_test_helper.h" |
| #include "src/tint/sem/test_helper.h" |
| |
| namespace tint::resolver { |
| |
| template <typename T> |
| inline const auto kPiOver2 = T(UnwrapNumber<T>(1.57079632679489661923)); |
| |
| template <typename T> |
| inline const auto kPiOver4 = T(UnwrapNumber<T>(0.785398163397448309616)); |
| |
| template <typename T> |
| inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846)); |
| |
| /// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args |
| template <size_t N> |
| inline void CollectScalars(const sem::Constant* c, utils::Vector<builder::Scalar, N>& scalars) { |
| Switch( |
| c->Type(), // |
| [&](const sem::AbstractInt*) { scalars.Push(c->As<AInt>()); }, |
| [&](const sem::AbstractFloat*) { scalars.Push(c->As<AFloat>()); }, |
| [&](const sem::Bool*) { scalars.Push(c->As<bool>()); }, |
| [&](const sem::I32*) { scalars.Push(c->As<i32>()); }, |
| [&](const sem::U32*) { scalars.Push(c->As<u32>()); }, |
| [&](const sem::F32*) { scalars.Push(c->As<f32>()); }, |
| [&](const sem::F16*) { scalars.Push(c->As<f16>()); }, |
| [&](Default) { |
| size_t i = 0; |
| while (auto* child = c->Index(i++)) { |
| CollectScalars(child, scalars); |
| } |
| }); |
| } |
| |
| /// Walks the sem::Constant @p c, returning all the inner-most scalar values. |
| inline utils::Vector<builder::Scalar, 16> ScalarsFrom(const sem::Constant* c) { |
| utils::Vector<builder::Scalar, 16> out; |
| CollectScalars(c, out); |
| return out; |
| } |
| |
| template <typename T> |
| inline auto Abs(const Number<T>& v) { |
| if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) { |
| return v; |
| } else { |
| return Number<T>(std::abs(v)); |
| } |
| } |
| |
/// Options that tune how CheckConstant() compares the evaluated and expected values.
/// A default-constructed instance has both booleans false and no epsilon.
struct CheckConstantFlags {
    /// Expected value may be positive or negative
    bool pos_or_neg{false};
    /// Expected value should be compared using EXPECT_FLOAT_EQ instead of EQ, or EXPECT_NEAR if
    /// float_compare_epsilon is set.
    bool float_compare{false};
    /// Expected value should be compared using EXPECT_NEAR if float_compare is set.
    std::optional<double> float_compare_epsilon{};
};
| |
| /// CheckConstant checks that @p got_constant, the result value of |
| /// constant-evaluation is equal to @p expected_value. |
| /// @param got_constant the constant value evaluated by the resolver |
| /// @param expected_value the expected value for the test |
| /// @param flags optional flags for controlling the comparisons |
| inline void CheckConstant(const sem::Constant* got_constant, |
| const builder::Value& expected_value, |
| CheckConstantFlags flags = {}) { |
| auto values_flat = ScalarsFrom(got_constant); |
| auto expected_values_flat = expected_value.args; |
| ASSERT_EQ(values_flat.Length(), expected_values_flat.Length()); |
| for (size_t i = 0; i < values_flat.Length(); ++i) { |
| auto& got_scalar = values_flat[i]; |
| auto& expected_scalar = expected_values_flat[i]; |
| std::visit( |
| [&](const auto& expected) { |
| using T = std::decay_t<decltype(expected)>; |
| |
| ASSERT_TRUE(std::holds_alternative<T>(got_scalar)); |
| auto got = std::get<T>(got_scalar); |
| |
| if constexpr (std::is_same_v<bool, T>) { |
| EXPECT_EQ(got, expected) << "index: " << i; |
| } else if constexpr (IsFloatingPoint<T>) { |
| if (std::isnan(expected)) { |
| EXPECT_TRUE(std::isnan(got)) << "index: " << i; |
| } else { |
| if (flags.pos_or_neg) { |
| got = Abs(got); |
| } |
| if (flags.float_compare) { |
| if (flags.float_compare_epsilon) { |
| EXPECT_NEAR(got, expected, *flags.float_compare_epsilon) |
| << "index: " << i; |
| } else { |
| EXPECT_FLOAT_EQ(got, expected) << "index: " << i; |
| } |
| } else { |
| EXPECT_EQ(got, expected) << "index: " << i; |
| } |
| } |
| } else { |
| if (flags.pos_or_neg) { |
| got = Abs(got); |
| } |
| EXPECT_EQ(got, expected) << "index: " << i; |
| |
| // Check that the constant's integer doesn't contain unexpected |
| // data in the MSBs that are outside of the bit-width of T. |
| EXPECT_EQ(AInt(got), AInt(expected)) << "index: " << i; |
| } |
| }, |
| expected_scalar); |
| } |
| } |
| |
| template <typename T> |
| inline constexpr auto Negate(const Number<T>& v) { |
| if constexpr (std::is_integral_v<T>) { |
| if constexpr (std::is_signed_v<T>) { |
| // For signed integrals, avoid C++ UB by not negating the smallest negative number. In |
| // WGSL, this operation is well defined to return the same value, see: |
| // https://gpuweb.github.io/gpuweb/wgsl/#arithmetic-expr. |
| if (v == std::numeric_limits<T>::min()) { |
| return v; |
| } |
| return -v; |
| |
| } else { |
| // Allow negating unsigned values |
| using ST = std::make_signed_t<T>; |
| auto as_signed = Number<ST>{static_cast<ST>(v)}; |
| return Number<T>{static_cast<T>(Negate(as_signed))}; |
| } |
| } else { |
| // float case |
| return -v; |
| } |
| } |
| |
| TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW); |
| template <typename T> |
| inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) { |
| if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) { |
| // For signed integrals, avoid C++ UB by multiplying as unsigned |
| using UT = std::make_unsigned_t<T>; |
| return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2)); |
| } else { |
| return static_cast<Number<T>>(v1 * v2); |
| } |
| } |
| TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW); |
| |
| TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW); |
| template <typename T> |
| inline constexpr Number<T> Add(Number<T> v1, Number<T> v2) { |
| if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) { |
| // For signed integrals, avoid C++ UB by adding as unsigned |
| using UT = std::make_unsigned_t<T>; |
| return static_cast<Number<T>>(static_cast<UT>(v1) + static_cast<UT>(v2)); |
| } else { |
| return static_cast<Number<T>>(v1 + v2); |
| } |
| } |
| TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW); |
| |
/// Concatenates any number of std::vectors.
/// The elements of every vector in @p vs are moved, in order, onto the end of @p v1, and the
/// combined vector is returned by value. @p v1 is moved from, even when it is passed as an
/// lvalue, and the elements of @p vs are left in a moved-from state.
/// Calling with only @p v1 is valid and returns its contents unchanged.
/// @param v1 the vector that receives the elements and is returned
/// @param vs the vectors whose elements are appended to @p v1
/// @returns the concatenated vector
template <typename Vec, typename... Vecs>
[[nodiscard]] inline auto Concat(Vec&& v1, Vecs&&... vs) {
    // Binary fold seeded with v1.size(): a unary fold over `+` is ill-formed when `vs` is empty.
    const auto total_size = (v1.size() + ... + vs.size());
    v1.reserve(total_size);
    (std::move(vs.begin(), vs.end(), std::back_inserter(v1)), ...);
    return std::move(v1);
}
| |
/// Appends the elements of the vectors @p vs onto the end of @p v1.
/// Elements are moved out of @p vs in order, leaving them in a moved-from state.
/// Calling with only @p v1 is valid and leaves its contents unchanged.
/// @param v1 the vector that receives the elements
/// @param vs the vectors whose elements are appended to @p v1
template <typename Vec, typename... Vecs>
inline void ConcatInto(Vec& v1, Vecs&&... vs) {
    // Binary fold seeded with v1.size(): a unary fold over `+` is ill-formed when `vs` is empty.
    const auto total_size = (v1.size() + ... + vs.size());
    v1.reserve(total_size);
    (std::move(vs.begin(), vs.end(), std::back_inserter(v1)), ...);
}
| |
/// Appends the vectors @p vs onto @p v1 via ConcatInto() when the compile-time @p condition is
/// true; when it is false the call does nothing.
/// @param v1 the vector that receives the elements
/// @param vs the vectors forwarded to ConcatInto()
template <bool condition, typename Vec, typename... Vecs>
inline void ConcatIntoIf([[maybe_unused]] Vec& v1, [[maybe_unused]] Vecs&&... vs) {
    if constexpr (!condition) {
        return;
    } else {
        ConcatInto(v1, std::forward<Vecs>(vs)...);
    }
}
| |
| /// Returns the overflow error message for binary ops |
| template <typename NumberT> |
| inline std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) { |
| std::stringstream ss; |
| ss << std::setprecision(20); |
| ss << "'" << lhs.value << " " << op << " " << rhs.value << "' cannot be represented as '" |
| << FriendlyName<NumberT>() << "'"; |
| return ss.str(); |
| } |
| |
/// Builds the overflow error message for a conversion.
/// The value is streamed with a precision of 20.
/// @param value the value that could not be converted
/// @param target_ty the name of the type the value was being converted to
/// @returns "value <value> cannot be represented as '<target_ty>'"
template <typename VALUE_TY>
std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
    std::stringstream msg;
    msg << std::setprecision(20) << "value " << value << " cannot be represented as "
        << "'" << target_ty << "'";
    return msg.str();
}
| |
| /// Returns the overflow error message for exponentiation |
| template <typename NumberT> |
| std::string OverflowExpErrorMessage(std::string_view base, NumberT value) { |
| std::stringstream ss; |
| ss << std::setprecision(20); |
| ss << base << "^" << value << " cannot be represented as " |
| << "'" << FriendlyName<NumberT>() << "'"; |
| return ss.str(); |
| } |
| |
| using builder::IsValue; |
| using builder::Mat; |
| using builder::Val; |
| using builder::Value; |
| using builder::Vec; |
| |
| // Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops |
| // traversing, and return Action::kStop; if the function returns Action::kContinue, it continues and |
| // returns Action::kContinue when done. |
| // TODO(amaiorano): Move to Constant.h? |
| enum class Action { kStop, kContinue }; |
| template <typename Func> |
| inline Action ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) { |
| EXPECT_EQ(a->Type(), b->Type()); |
| size_t i = 0; |
| while (true) { |
| auto* a_elem = a->Index(i); |
| if (!a_elem) { |
| break; |
| } |
| auto* b_elem = b->Index(i); |
| if (ForEachElemPair(a_elem, b_elem, f) == Action::kStop) { |
| return Action::kStop; |
| } |
| i++; |
| } |
| if (i == 0) { |
| return f(a, b); |
| } |
| return Action::kContinue; |
| } |
| |
| /// Defines common bit value patterns for the input `NumberT` type used for testing. |
| template <typename NumberT> |
| struct BitValues { |
| /// The unwrapped number type |
| using T = UnwrapNumber<NumberT>; |
| /// The unsigned unwrapped number type |
| using UT = std::make_unsigned_t<T>; |
| /// Details |
| struct detail { |
| /// Unsigned type of `T` |
| using UT = std::make_unsigned_t<T>; |
| /// Size in bits of type T |
| static constexpr size_t NumBits = sizeof(T) * 8; |
| /// All bits set 1 |
| static constexpr T All = T{~T{0}}; |
| /// Only left-most bits set to 1, rest set to 0 |
| static constexpr T LeftMost = static_cast<T>(UT{1} << (NumBits - 1u)); |
| /// Only left-most bits set to 0, rest set to 1 |
| static constexpr T AllButLeftMost = T{~LeftMost}; |
| /// Only two left-most bits set to 1, rest set to 0 |
| static constexpr T TwoLeftMost = static_cast<T>(UT{0b11} << (NumBits - 2u)); |
| /// Only two left-most bits set to 0, rest set to 1 |
| static constexpr T AllButTwoLeftMost = T{~TwoLeftMost}; |
| /// Only right-most bit set to 1, rest set to 0 |
| static constexpr T RightMost = T{1}; |
| /// Only right-most bit set to 0, rest set to 1 |
| static constexpr T AllButRightMost = T{~RightMost}; |
| }; |
| |
| /// Size in bits of type NumberT |
| static inline const size_t NumBits = detail::NumBits; |
| /// All bits set 1 |
| static inline const NumberT All = NumberT{detail::All}; |
| /// Only left-most bits set to 1, rest set to 0 |
| static inline const NumberT LeftMost = NumberT{detail::LeftMost}; |
| /// Only left-most bits set to 0, rest set to 1 |
| static inline const NumberT AllButLeftMost = NumberT{detail::AllButLeftMost}; |
| /// Only two left-most bits set to 1, rest set to 0 |
| static inline const NumberT TwoLeftMost = NumberT{detail::TwoLeftMost}; |
| /// Only two left-most bits set to 0, rest set to 1 |
| static inline const NumberT AllButTwoLeftMost = NumberT{detail::AllButTwoLeftMost}; |
| /// Only right-most bit set to 1, rest set to 0 |
| static inline const NumberT RightMost = NumberT{detail::RightMost}; |
| /// Only right-most bit set to 0, rest set to 1 |
| static inline const NumberT AllButRightMost = NumberT{detail::AllButRightMost}; |
| |
| /// Performs a left-shift of `val` by `shiftBy`, both of varying type cast to `T`. |
| /// @param val value to shift left |
| /// @param shiftBy number of bits to shift left by |
| /// @returns the shifted value |
| template <typename U, typename V> |
| static constexpr NumberT Lsh(U val, V shiftBy) { |
| return NumberT{static_cast<T>(static_cast<UT>(val) << static_cast<UT>(shiftBy))}; |
| } |
| }; |
| |
/// Fixture name for the constant-evaluation tests that include this header; it is a plain alias
/// of ResolverTest and adds no behaviour of its own.
using ResolverConstEvalTest = ResolverTest;
| |
| } // namespace tint::resolver |
| |
| #endif // SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_ |