tint/resolver: Split constant checking to utility
And add basic support for builtins returning structures.
Bug tint:1581
Change-Id: I67f987339b9a344e1915c69c9991803f0665305d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111242
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index bbca0d5..70176c7 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -106,19 +106,7 @@
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
- auto values_flat = ScalarArgsFrom(value);
- auto expected_values_flat = expected->Args();
- ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
- for (size_t i = 0; i < values_flat.values.Length(); ++i) {
- auto& a = values_flat.values[i];
- auto& b = expected_values_flat.values[i];
- EXPECT_EQ(a, b);
- if (expected->IsIntegral()) {
- // Check that the constant's integer doesn't contain unexpected
- // data in the MSBs that are outside of the bit-width of T.
- EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
- }
- }
+ CheckConstant(value, expected);
} else {
ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), c.expected.Failure().error);
diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc
index 7fbdf30..6c6d193 100644
--- a/src/tint/resolver/const_eval_builtin_test.cc
+++ b/src/tint/resolver/const_eval_builtin_test.cc
@@ -26,8 +26,9 @@
using resolver::operator<<;
struct Case {
- Case(utils::VectorRef<Types> in_args, Types expected_value)
- : args(std::move(in_args)), expected(Success{std::move(expected_value), false, false}) {}
+ Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values)
+ : args(std::move(in_args)),
+ expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
Case(utils::VectorRef<Types> in_args, std::string expected_err)
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
@@ -35,7 +36,7 @@
/// Expected value may be positive or negative
Case& PosOrNeg() {
Success s = expected.Get();
- s.pos_or_neg = true;
+ s.flags.pos_or_neg = true;
expected = s;
return *this;
}
@@ -43,15 +44,14 @@
/// Expected value should be compared using FLOAT_EQ instead of EQ
Case& FloatComp() {
Success s = expected.Get();
- s.float_compare = true;
+ s.flags.float_compare = true;
expected = s;
return *this;
}
struct Success {
- Types value;
- bool pos_or_neg = false;
- bool float_compare = false;
+ utils::Vector<Types, 2> values;
+ CheckConstantFlags flags;
};
struct Failure {
std::string error;
@@ -69,7 +69,20 @@
o << "expected: ";
if (c.expected) {
auto s = c.expected.Get();
- o << s.value << ", pos_or_neg: " << s.pos_or_neg;
+ if (s.values.Length() == 1) {
+ o << s.values[0];
+ } else {
+ o << "[";
+ for (auto& v : s.values) {
+ if (&v != &s.values[0]) {
+ o << ", ";
+ }
+ o << v;
+ }
+ o << "]";
+ }
+ o << ", pos_or_neg: " << s.flags.pos_or_neg;
+ o << ", float_compare: " << s.flags.float_compare;
} else {
o << "[ERROR: " << c.expected.Failure().error << "]";
}
@@ -80,7 +93,7 @@
/// Creates a Case with Values for args and result
static Case C(std::initializer_list<Types> args, Types result) {
- return Case{utils::Vector<Types, 8>{args}, std::move(result)};
+ return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
}
/// Convenience overload that creates a Case with just scalars
@@ -91,7 +104,7 @@
}
Types result = Val(0_a);
std::visit([&](auto&& v) { result = Val(v); }, sresult);
- return Case{std::move(args), std::move(result)};
+ return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
}
/// Creates a Case with Values for args and expected error
@@ -127,9 +140,6 @@
if (c.expected) {
auto expected_case = c.expected.Get();
- auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this);
- GlobalConst("E", expected_expr);
-
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
@@ -138,43 +148,19 @@
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
- auto* expected_sem = Sem().Get(expected_expr);
- const sem::Constant* expected_value = expected_sem->ConstantValue();
- ASSERT_NE(expected_value, nullptr);
- EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
-
- // @TODO(amaiorano): Rewrite using ScalarArgsFrom()
- ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
- std::visit(
- [&](auto&& ct_expected) {
- using T = typename std::decay_t<decltype(ct_expected)>::ElementType;
-
- auto v = a->As<T>();
- auto e = b->As<T>();
- if constexpr (std::is_same_v<bool, T>) {
- EXPECT_EQ(v, e);
- } else if constexpr (IsFloatingPoint<T>) {
- if (std::isnan(e)) {
- EXPECT_TRUE(std::isnan(v));
- } else {
- auto vf = (expected_case.pos_or_neg ? Abs(v) : v);
- if (expected_case.float_compare) {
- EXPECT_FLOAT_EQ(vf, e);
- } else {
- EXPECT_EQ(vf, e);
- }
- }
- } else {
- EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e);
- // Check that the constant's integer doesn't contain unexpected
- // data in the MSBs that are outside of the bit-width of T.
- EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
- }
- },
- expected_case.value);
-
- return HasFailure() ? Action::kStop : Action::kContinue;
- });
+ if (value->Type()->Is<sem::Struct>()) {
+ // The result type of the constant-evaluated expression is a structure.
+ // Compare each of the fields individually.
+ for (size_t i = 0; i < expected_case.values.Length(); i++) {
+ CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]),
+ expected_case.flags);
+ }
+ } else {
+ // Return type is not a structure. Just compare the single value
+ ASSERT_EQ(expected_case.values.Length(), 1u)
+ << "const-eval returned non-struct, but Case expected multiple values";
+ CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags);
+ }
} else {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), c.expected.Failure().error);
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index cccbf76..908d7e2 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -62,6 +62,81 @@
}
template <typename T>
+inline auto Abs(const Number<T>& v) {
+ if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
+ return v;
+ } else {
+ return Number<T>(std::abs(v));
+ }
+}
+
+/// Flags that can be passed to CheckConstant()
+struct CheckConstantFlags {
+ /// Expected value may be positive or negative
+ bool pos_or_neg = false;
+ /// Expected value should be compared using FLOAT_EQ instead of EQ
+ bool float_compare = false;
+};
+
+/// CheckConstant checks that @p got_constant, the result value of
+/// constant-evaluation is equal to @p expected_value.
+/// @param got_constant the constant value evaluated by the resolver
+/// @param expected_value the expected value for the test
+/// @param flags optional flags for controlling the comparisons
+inline void CheckConstant(const sem::Constant* got_constant,
+ const builder::ValueBase* expected_value,
+ CheckConstantFlags flags = {}) {
+ auto values_flat = ScalarArgsFrom(got_constant);
+ auto expected_values_flat = expected_value->Args();
+ ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
+ for (size_t i = 0; i < values_flat.values.Length(); ++i) {
+ auto& got_scalar = values_flat.values[i];
+ auto& expected_scalar = expected_values_flat.values[i];
+ std::visit(
+ [&](auto&& expected) {
+ using T = std::decay_t<decltype(expected)>;
+
+ ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
+ auto got = std::get<T>(got_scalar);
+
+ if constexpr (std::is_same_v<bool, T>) {
+ EXPECT_EQ(got, expected);
+ } else if constexpr (IsFloatingPoint<T>) {
+ if (std::isnan(expected)) {
+ EXPECT_TRUE(std::isnan(got));
+ } else {
+ if (flags.pos_or_neg) {
+ auto got_abs = Abs(got);
+ if (flags.float_compare) {
+ EXPECT_FLOAT_EQ(got_abs, expected);
+ } else {
+ EXPECT_EQ(got_abs, expected);
+ }
+ } else {
+ if (flags.float_compare) {
+ EXPECT_FLOAT_EQ(got, expected);
+ } else {
+ EXPECT_EQ(got, expected);
+ }
+ }
+ }
+ } else {
+ if (flags.pos_or_neg) {
+ auto got_abs = Abs(got);
+ EXPECT_EQ(got_abs, expected);
+ } else {
+ EXPECT_EQ(got, expected);
+ }
+ // Check that the constant's integer doesn't contain unexpected
+ // data in the MSBs that are outside of the bit-width of T.
+ EXPECT_EQ(AInt(got), AInt(expected));
+ }
+ },
+ expected_scalar);
+ }
+}
+
+template <typename T>
inline constexpr auto Negate(const Number<T>& v) {
if constexpr (std::is_integral_v<T>) {
if constexpr (std::is_signed_v<T>) {
@@ -85,15 +160,6 @@
}
}
-template <typename T>
-inline auto Abs(const Number<T>& v) {
- if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
- return v;
- } else {
- return Number<T>(std::abs(v));
- }
-}
-
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
template <typename T>
inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {