tint/resolver: Split constant checking to utility

And add basic support for builtins returning structures.

Bug tint:1581

Change-Id: I67f987339b9a344e1915c69c9991803f0665305d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111242
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index bbca0d5..70176c7 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -106,19 +106,7 @@
         ASSERT_NE(value, nullptr);
         EXPECT_TYPE(value->Type(), sem->Type());
 
-        auto values_flat = ScalarArgsFrom(value);
-        auto expected_values_flat = expected->Args();
-        ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
-        for (size_t i = 0; i < values_flat.values.Length(); ++i) {
-            auto& a = values_flat.values[i];
-            auto& b = expected_values_flat.values[i];
-            EXPECT_EQ(a, b);
-            if (expected->IsIntegral()) {
-                // Check that the constant's integer doesn't contain unexpected
-                // data in the MSBs that are outside of the bit-width of T.
-                EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
-            }
-        }
+        CheckConstant(value, expected);
     } else {
         ASSERT_FALSE(r()->Resolve());
         EXPECT_EQ(r()->error(), c.expected.Failure().error);
diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc
index 7fbdf30..6c6d193 100644
--- a/src/tint/resolver/const_eval_builtin_test.cc
+++ b/src/tint/resolver/const_eval_builtin_test.cc
@@ -26,8 +26,9 @@
 using resolver::operator<<;
 
 struct Case {
-    Case(utils::VectorRef<Types> in_args, Types expected_value)
-        : args(std::move(in_args)), expected(Success{std::move(expected_value), false, false}) {}
+    Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values)
+        : args(std::move(in_args)),
+          expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
 
     Case(utils::VectorRef<Types> in_args, std::string expected_err)
         : args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
@@ -35,7 +36,7 @@
     /// Expected value may be positive or negative
     Case& PosOrNeg() {
         Success s = expected.Get();
-        s.pos_or_neg = true;
+        s.flags.pos_or_neg = true;
         expected = s;
         return *this;
     }
@@ -43,15 +44,14 @@
     /// Expected value should be compared using FLOAT_EQ instead of EQ
     Case& FloatComp() {
         Success s = expected.Get();
-        s.float_compare = true;
+        s.flags.float_compare = true;
         expected = s;
         return *this;
     }
 
     struct Success {
-        Types value;
-        bool pos_or_neg = false;
-        bool float_compare = false;
+        utils::Vector<Types, 2> values;
+        CheckConstantFlags flags;
     };
     struct Failure {
         std::string error;
@@ -69,7 +69,20 @@
     o << "expected: ";
     if (c.expected) {
         auto s = c.expected.Get();
-        o << s.value << ", pos_or_neg: " << s.pos_or_neg;
+        if (s.values.Length() == 1) {
+            o << s.values[0];
+        } else {
+            o << "[";
+            for (auto& v : s.values) {
+                if (&v != &s.values[0]) {
+                    o << ", ";
+                }
+                o << v;
+            }
+            o << "]";
+        }
+        o << ", pos_or_neg: " << s.flags.pos_or_neg;
+        o << ", float_compare: " << s.flags.float_compare;
     } else {
         o << "[ERROR: " << c.expected.Failure().error << "]";
     }
@@ -80,7 +93,7 @@
 
 /// Creates a Case with Values for args and result
 static Case C(std::initializer_list<Types> args, Types result) {
-    return Case{utils::Vector<Types, 8>{args}, std::move(result)};
+    return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
 }
 
 /// Convenience overload that creates a Case with just scalars
@@ -91,7 +104,7 @@
     }
     Types result = Val(0_a);
     std::visit([&](auto&& v) { result = Val(v); }, sresult);
-    return Case{std::move(args), std::move(result)};
+    return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
 }
 
 /// Creates a Case with Values for args and expected error
@@ -127,9 +140,6 @@
     if (c.expected) {
         auto expected_case = c.expected.Get();
 
-        auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this);
-        GlobalConst("E", expected_expr);
-
         ASSERT_TRUE(r()->Resolve()) << r()->error();
 
         auto* sem = Sem().Get(expr);
@@ -138,43 +148,19 @@
         ASSERT_NE(value, nullptr);
         EXPECT_TYPE(value->Type(), sem->Type());
 
-        auto* expected_sem = Sem().Get(expected_expr);
-        const sem::Constant* expected_value = expected_sem->ConstantValue();
-        ASSERT_NE(expected_value, nullptr);
-        EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
-
-        // @TODO(amaiorano): Rewrite using ScalarArgsFrom()
-        ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
-            std::visit(
-                [&](auto&& ct_expected) {
-                    using T = typename std::decay_t<decltype(ct_expected)>::ElementType;
-
-                    auto v = a->As<T>();
-                    auto e = b->As<T>();
-                    if constexpr (std::is_same_v<bool, T>) {
-                        EXPECT_EQ(v, e);
-                    } else if constexpr (IsFloatingPoint<T>) {
-                        if (std::isnan(e)) {
-                            EXPECT_TRUE(std::isnan(v));
-                        } else {
-                            auto vf = (expected_case.pos_or_neg ? Abs(v) : v);
-                            if (expected_case.float_compare) {
-                                EXPECT_FLOAT_EQ(vf, e);
-                            } else {
-                                EXPECT_EQ(vf, e);
-                            }
-                        }
-                    } else {
-                        EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e);
-                        // Check that the constant's integer doesn't contain unexpected
-                        // data in the MSBs that are outside of the bit-width of T.
-                        EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
-                    }
-                },
-                expected_case.value);
-
-            return HasFailure() ? Action::kStop : Action::kContinue;
-        });
+        if (value->Type()->Is<sem::Struct>()) {
+            // The result type of the constant-evaluated expression is a structure.
+            // Compare each of the fields individually.
+            for (size_t i = 0; i < expected_case.values.Length(); i++) {
+                CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]),
+                              expected_case.flags);
+            }
+        } else {
+            // Return type is not a structure. Just compare the single value
+            ASSERT_EQ(expected_case.values.Length(), 1u)
+                << "const-eval returned non-struct, but Case expected multiple values";
+            CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags);
+        }
     } else {
         EXPECT_FALSE(r()->Resolve());
         EXPECT_EQ(r()->error(), c.expected.Failure().error);
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index cccbf76..908d7e2 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -62,6 +62,81 @@
 }
 
 template <typename T>
+inline auto Abs(const Number<T>& v) {
+    if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
+        return v;
+    } else {
+        return Number<T>(std::abs(v));
+    }
+}
+
+/// Flags that can be passed to CheckConstant()
+struct CheckConstantFlags {
+    /// Expected value may be positive or negative
+    bool pos_or_neg = false;
+    /// Expected value should be compared using FLOAT_EQ instead of EQ
+    bool float_compare = false;
+};
+
+/// CheckConstant checks that @p got_constant, the result value of
+/// constant-evaluation is equal to @p expected_value.
+/// @param got_constant the constant value evaluated by the resolver
+/// @param expected_value the expected value for the test
+/// @param flags optional flags for controlling the comparisons
+inline void CheckConstant(const sem::Constant* got_constant,
+                          const builder::ValueBase* expected_value,
+                          CheckConstantFlags flags = {}) {
+    auto values_flat = ScalarArgsFrom(got_constant);
+    auto expected_values_flat = expected_value->Args();
+    ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
+    for (size_t i = 0; i < values_flat.values.Length(); ++i) {
+        auto& got_scalar = values_flat.values[i];
+        auto& expected_scalar = expected_values_flat.values[i];
+        std::visit(
+            [&](auto&& expected) {
+                using T = std::decay_t<decltype(expected)>;
+
+                ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
+                auto got = std::get<T>(got_scalar);
+
+                if constexpr (std::is_same_v<bool, T>) {
+                    EXPECT_EQ(got, expected);
+                } else if constexpr (IsFloatingPoint<T>) {
+                    if (std::isnan(expected)) {
+                        EXPECT_TRUE(std::isnan(got));
+                    } else {
+                        if (flags.pos_or_neg) {
+                            auto got_abs = Abs(got);
+                            if (flags.float_compare) {
+                                EXPECT_FLOAT_EQ(got_abs, expected);
+                            } else {
+                                EXPECT_EQ(got_abs, expected);
+                            }
+                        } else {
+                            if (flags.float_compare) {
+                                EXPECT_FLOAT_EQ(got, expected);
+                            } else {
+                                EXPECT_EQ(got, expected);
+                            }
+                        }
+                    }
+                } else {
+                    if (flags.pos_or_neg) {
+                        auto got_abs = Abs(got);
+                        EXPECT_EQ(got_abs, expected);
+                    } else {
+                        EXPECT_EQ(got, expected);
+                    }
+                    // Check that the constant's integer doesn't contain unexpected
+                    // data in the MSBs that are outside of the bit-width of T.
+                    EXPECT_EQ(AInt(got), AInt(expected));
+                }
+            },
+            expected_scalar);
+    }
+}
+
+template <typename T>
 inline constexpr auto Negate(const Number<T>& v) {
     if constexpr (std::is_integral_v<T>) {
         if constexpr (std::is_signed_v<T>) {
@@ -85,15 +160,6 @@
     }
 }
 
-template <typename T>
-inline auto Abs(const Number<T>& v) {
-    if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
-        return v;
-    } else {
-        return Number<T>(std::abs(v));
-    }
-}
-
 TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
 template <typename T>
 inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {