blob: 89757522853f7411dac728f1dd4407e330dbd4cc [file] [log] [blame]
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_
#define SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/test_helper.h"
namespace tint::resolver {
template <typename T>
inline const auto kPiOver2 = T(UnwrapNumber<T>(1.57079632679489661923));
template <typename T>
inline const auto kPiOver4 = T(UnwrapNumber<T>(0.785398163397448309616));
template <typename T>
inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args
inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
Switch(
c->Type(), //
[&](const sem::AbstractInt*) { args.values.Push(c->As<AInt>()); },
[&](const sem::AbstractFloat*) { args.values.Push(c->As<AFloat>()); },
[&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
[&](const sem::I32*) { args.values.Push(c->As<i32>()); },
[&](const sem::U32*) { args.values.Push(c->As<u32>()); },
[&](const sem::F32*) { args.values.Push(c->As<f32>()); },
[&](const sem::F16*) { args.values.Push(c->As<f16>()); },
[&](Default) {
size_t i = 0;
while (auto* child = c->Index(i++)) {
CollectScalarArgs(child, args);
}
});
}
/// Walks the sem::Constant @p c, returning all the inner-most scalar values.
inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
builder::ScalarArgs out;
CollectScalarArgs(c, out);
return out;
}
template <typename T>
inline auto Abs(const Number<T>& v) {
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
return v;
} else {
return Number<T>(std::abs(v));
}
}
/// Flags that can be passed to CheckConstant()
struct CheckConstantFlags {
/// Expected value may be positive or negative
bool pos_or_neg = false;
/// Expected value should be compared using EXPECT_FLOAT_EQ instead of EQ, or EXPECT_NEAR if
/// float_compare_epsilon is set.
bool float_compare = false;
/// Expected value should be compared using EXPECT_NEAR if float_compare is set.
std::optional<double> float_compare_epsilon;
};
/// CheckConstant checks that @p got_constant, the result value of
/// constant-evaluation is equal to @p expected_value.
/// @param got_constant the constant value evaluated by the resolver
/// @param expected_value the expected value for the test
/// @param flags optional flags for controlling the comparisons
inline void CheckConstant(const sem::Constant* got_constant,
const builder::ValueBase* expected_value,
CheckConstantFlags flags = {}) {
auto values_flat = ScalarArgsFrom(got_constant);
auto expected_values_flat = expected_value->Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& got_scalar = values_flat.values[i];
auto& expected_scalar = expected_values_flat.values[i];
std::visit(
[&](auto&& expected) {
using T = std::decay_t<decltype(expected)>;
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
auto got = std::get<T>(got_scalar);
if constexpr (std::is_same_v<bool, T>) {
EXPECT_EQ(got, expected);
} else if constexpr (IsFloatingPoint<T>) {
if (std::isnan(expected)) {
EXPECT_TRUE(std::isnan(got));
} else {
if (flags.pos_or_neg) {
got = Abs(got);
}
if (flags.float_compare) {
if (flags.float_compare_epsilon) {
EXPECT_NEAR(got, expected, *flags.float_compare_epsilon);
} else {
EXPECT_FLOAT_EQ(got, expected);
}
} else {
EXPECT_EQ(got, expected);
}
}
} else {
if (flags.pos_or_neg) {
auto got_abs = Abs(got);
EXPECT_EQ(got_abs, expected);
} else {
EXPECT_EQ(got, expected);
}
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(AInt(got), AInt(expected));
}
},
expected_scalar);
}
}
template <typename T>
inline constexpr auto Negate(const Number<T>& v) {
if constexpr (std::is_integral_v<T>) {
if constexpr (std::is_signed_v<T>) {
// For signed integrals, avoid C++ UB by not negating the smallest negative number. In
// WGSL, this operation is well defined to return the same value, see:
// https://gpuweb.github.io/gpuweb/wgsl/#arithmetic-expr.
if (v == std::numeric_limits<T>::min()) {
return v;
}
return -v;
} else {
// Allow negating unsigned values
using ST = std::make_signed_t<T>;
auto as_signed = Number<ST>{static_cast<ST>(v)};
return Number<T>{static_cast<T>(Negate(as_signed))};
}
} else {
// float case
return -v;
}
}
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
template <typename T>
inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// For signed integrals, avoid C++ UB by multiplying as unsigned
using UT = std::make_unsigned_t<T>;
return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2));
} else {
return static_cast<Number<T>>(v1 * v2);
}
}
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
template <typename T>
inline constexpr Number<T> Add(Number<T> v1, Number<T> v2) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// For signed integrals, avoid C++ UB by adding as unsigned
using UT = std::make_unsigned_t<T>;
return static_cast<Number<T>>(static_cast<UT>(v1) + static_cast<UT>(v2));
} else {
return static_cast<Number<T>>(v1 + v2);
}
}
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
// Concats any number of std::vectors
template <typename Vec, typename... Vecs>
[[nodiscard]] inline auto Concat(Vec&& v1, Vecs&&... vs) {
auto total_size = v1.size() + (vs.size() + ...);
v1.reserve(total_size);
(std::move(vs.begin(), vs.end(), std::back_inserter(v1)), ...);
return std::move(v1);
}
// Concats vectors `vs` into `v1`
template <typename Vec, typename... Vecs>
inline void ConcatInto(Vec& v1, Vecs&&... vs) {
auto total_size = v1.size() + (vs.size() + ...);
v1.reserve(total_size);
(std::move(vs.begin(), vs.end(), std::back_inserter(v1)), ...);
}
// Concats vectors `vs` into `v1` iff `condition` is true
template <bool condition, typename Vec, typename... Vecs>
inline void ConcatIntoIf([[maybe_unused]] Vec& v1, [[maybe_unused]] Vecs&&... vs) {
if constexpr (condition) {
ConcatInto(v1, std::forward<Vecs>(vs)...);
}
}
/// Returns the overflow error message for binary ops
template <typename NumberT>
inline std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) {
std::stringstream ss;
ss << std::setprecision(20);
ss << "'" << lhs.value << " " << op << " " << rhs.value << "' cannot be represented as '"
<< FriendlyName<NumberT>() << "'";
return ss.str();
}
/// Returns the overflow error message for conversions
template <typename VALUE_TY>
std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) {
std::stringstream ss;
ss << std::setprecision(20);
ss << "value " << value << " cannot be represented as "
<< "'" << target_ty << "'";
return ss.str();
}
/// Returns the overflow error message for exponentiation
template <typename NumberT>
std::string OverflowExpErrorMessage(std::string_view base, NumberT value) {
std::stringstream ss;
ss << std::setprecision(20);
ss << base << "^" << value << " cannot be represented as "
<< "'" << FriendlyName<NumberT>() << "'";
return ss.str();
}
using builder::IsValue;
using builder::Mat;
using builder::Val;
using builder::Value;
using builder::ValueBase;
using builder::Vec;
using Types = std::variant< //
Value<AInt>,
Value<AFloat>,
Value<u32>,
Value<i32>,
Value<f32>,
Value<f16>,
Value<bool>,
Value<builder::vec2<AInt>>,
Value<builder::vec2<AFloat>>,
Value<builder::vec2<u32>>,
Value<builder::vec2<i32>>,
Value<builder::vec2<f32>>,
Value<builder::vec2<f16>>,
Value<builder::vec2<bool>>,
Value<builder::vec3<AInt>>,
Value<builder::vec3<AFloat>>,
Value<builder::vec3<u32>>,
Value<builder::vec3<i32>>,
Value<builder::vec3<f32>>,
Value<builder::vec3<f16>>,
Value<builder::vec3<bool>>,
Value<builder::vec4<AInt>>,
Value<builder::vec4<AFloat>>,
Value<builder::vec4<u32>>,
Value<builder::vec4<i32>>,
Value<builder::vec4<f32>>,
Value<builder::vec4<f16>>,
Value<builder::vec4<bool>>,
Value<builder::mat2x2<AInt>>,
Value<builder::mat2x2<AFloat>>,
Value<builder::mat2x2<f32>>,
Value<builder::mat2x2<f16>>,
Value<builder::mat2x3<AInt>>,
Value<builder::mat2x3<AFloat>>,
Value<builder::mat2x3<f32>>,
Value<builder::mat2x3<f16>>,
Value<builder::mat3x2<AInt>>,
Value<builder::mat3x2<AFloat>>,
Value<builder::mat3x2<f32>>,
Value<builder::mat3x2<f16>>
//
>;
/// Returns the current Value<T> in the `types` variant as a `ValueBase` pointer to use the
/// polymorphic API. This trades longer compile times using std::variant for longer runtime via
/// virtual function calls.
template <typename ValueVariant>
inline const ValueBase* ToValueBase(const ValueVariant& types) {
return std::visit(
[](auto&& t) -> const ValueBase* { return static_cast<const ValueBase*>(&t); }, types);
}
/// Prints Types to ostream
inline std::ostream& operator<<(std::ostream& o, const Types& types) {
return ToValueBase(types)->Print(o);
}
// Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops
// traversing, and return Action::kStop; if the function returns Action::kContinue, it continues and
// returns Action::kContinue when done.
// TODO(amaiorano): Move to Constant.h?
enum class Action { kStop, kContinue };
template <typename Func>
inline Action ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
EXPECT_EQ(a->Type(), b->Type());
size_t i = 0;
while (true) {
auto* a_elem = a->Index(i);
if (!a_elem) {
break;
}
auto* b_elem = b->Index(i);
if (ForEachElemPair(a_elem, b_elem, f) == Action::kStop) {
return Action::kStop;
}
i++;
}
if (i == 0) {
return f(a, b);
}
return Action::kContinue;
}
/// Defines common bit value patterns for the input `NumberT` type used for testing.
template <typename NumberT>
struct BitValues {
/// The unwrapped number type
using T = UnwrapNumber<NumberT>;
/// The unsigned unwrapped number type
using UT = std::make_unsigned_t<T>;
/// Details
struct detail {
/// Unsigned type of `T`
using UT = std::make_unsigned_t<T>;
/// Size in bits of type T
static constexpr size_t NumBits = sizeof(T) * 8;
/// All bits set 1
static constexpr T All = T{~T{0}};
/// Only left-most bits set to 1, rest set to 0
static constexpr T LeftMost = static_cast<T>(UT{1} << (NumBits - 1u));
/// Only left-most bits set to 0, rest set to 1
static constexpr T AllButLeftMost = T{~LeftMost};
/// Only two left-most bits set to 1, rest set to 0
static constexpr T TwoLeftMost = static_cast<T>(UT{0b11} << (NumBits - 2u));
/// Only two left-most bits set to 0, rest set to 1
static constexpr T AllButTwoLeftMost = T{~TwoLeftMost};
/// Only right-most bit set to 1, rest set to 0
static constexpr T RightMost = T{1};
/// Only right-most bit set to 0, rest set to 1
static constexpr T AllButRightMost = T{~RightMost};
};
/// Size in bits of type NumberT
static inline const size_t NumBits = detail::NumBits;
/// All bits set 1
static inline const NumberT All = NumberT{detail::All};
/// Only left-most bits set to 1, rest set to 0
static inline const NumberT LeftMost = NumberT{detail::LeftMost};
/// Only left-most bits set to 0, rest set to 1
static inline const NumberT AllButLeftMost = NumberT{detail::AllButLeftMost};
/// Only two left-most bits set to 1, rest set to 0
static inline const NumberT TwoLeftMost = NumberT{detail::TwoLeftMost};
/// Only two left-most bits set to 0, rest set to 1
static inline const NumberT AllButTwoLeftMost = NumberT{detail::AllButTwoLeftMost};
/// Only right-most bit set to 1, rest set to 0
static inline const NumberT RightMost = NumberT{detail::RightMost};
/// Only right-most bit set to 0, rest set to 1
static inline const NumberT AllButRightMost = NumberT{detail::AllButRightMost};
/// Performs a left-shift of `val` by `shiftBy`, both of varying type cast to `T`.
/// @param val value to shift left
/// @param shiftBy number of bits to shift left by
/// @returns the shifted value
template <typename U, typename V>
static constexpr NumberT Lsh(U val, V shiftBy) {
return NumberT{static_cast<T>(static_cast<UT>(val) << static_cast<UT>(shiftBy))};
}
};
using ResolverConstEvalTest = ResolverTest;
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_