tint: optimize compile time for const_eval_*_test files

Compile times are slow because the very large variants of
builder::Value<T>, combined with the many std::visit calls over these
variants, result in a combinatorial number of instantiations of the
visit callbacks.

To address this, I added a polymorphic base class ValueBase to Value<T>,
and replaced most of the std::visit-based compile-time code with runtime
virtual calls. For the two heaviest users of std::visit over the large
variants, compile times dropped by more than half (clang-10, debug):

const_eval_binary_op_test.cc: 19.079s to 7.736s
const_eval_unary_op_test.cc: 10.021s to 4.789s

Bug: tint:1711
Change-Id: Iba05e6ae1004ef0814250e2a8ea50aa2b26b85f2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105782
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc
index 18b28da..5f43499 100644
--- a/src/tint/resolver/const_eval_binary_op_test.cc
+++ b/src/tint/resolver/const_eval_binary_op_test.cc
@@ -54,47 +54,39 @@
     auto op = std::get<0>(GetParam());
     auto& c = std::get<1>(GetParam());
 
-    std::visit(
-        [&](auto&& expected) {
-            using T = typename std::decay_t<decltype(expected)>::ElementType;
-            if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
-                if (c.overflow) {
-                    // Overflow is not allowed for abstract types. This is tested separately.
-                    return;
-                }
-            }
+    auto* expected = ToValueBase(c.expected);
+    if (expected->IsAbstract() && c.overflow) {
+        // Overflow is not allowed for abstract types. This is tested separately.
+        return;
+    }
 
-            auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
-            auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
-            auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
+    auto* lhs = ToValueBase(c.lhs);
+    auto* rhs = ToValueBase(c.rhs);
 
-            GlobalConst("C", expr);
-            auto* expected_expr = expected.Expr(*this);
-            GlobalConst("E", expected_expr);
-            ASSERT_TRUE(r()->Resolve()) << r()->error();
+    auto* lhs_expr = lhs->Expr(*this);
+    auto* rhs_expr = rhs->Expr(*this);
+    auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
+    GlobalConst("C", expr);
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
 
-            auto* sem = Sem().Get(expr);
-            const sem::Constant* value = sem->ConstantValue();
-            ASSERT_NE(value, nullptr);
-            EXPECT_TYPE(value->Type(), sem->Type());
+    auto* sem = Sem().Get(expr);
+    const sem::Constant* value = sem->ConstantValue();
+    ASSERT_NE(value, nullptr);
+    EXPECT_TYPE(value->Type(), sem->Type());
 
-            auto* expected_sem = Sem().Get(expected_expr);
-            const sem::Constant* expected_value = expected_sem->ConstantValue();
-            ASSERT_NE(expected_value, nullptr);
-            EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
-
-            ForEachElemPair(value, expected_value,
-                            [&](const sem::Constant* a, const sem::Constant* b) {
-                                EXPECT_EQ(a->As<T>(), b->As<T>());
-                                if constexpr (IsIntegral<T>) {
-                                    // Check that the constant's integer doesn't contain unexpected
-                                    // data in the MSBs that are outside of the bit-width of T.
-                                    EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
-                                }
-                                return HasFailure() ? Action::kStop : Action::kContinue;
-                            });
-        },
-        c.expected);
+    auto values_flat = ScalarArgsFrom(value);
+    auto expected_values_flat = expected->Args();
+    ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
+    for (size_t i = 0; i < values_flat.values.Length(); ++i) {
+        auto& a = values_flat.values[i];
+        auto& b = expected_values_flat.values[i];
+        EXPECT_EQ(a, b);
+        if (expected->IsIntegral()) {
+            // Check that the constant's integer doesn't contain unexpected
+            // data in the MSBs that are outside of the bit-width of T.
+            EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
+        }
+    }
 }
 
 INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
@@ -658,21 +650,15 @@
 TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
     Enable(ast::Extension::kF16);
     auto& c = GetParam();
-    auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
-    auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
+    auto* lhs = ToValueBase(c.lhs);
+    auto* rhs = ToValueBase(c.rhs);
+    auto* lhs_expr = lhs->Expr(*this);
+    auto* rhs_expr = rhs->Expr(*this);
     auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
     GlobalConst("C", expr);
     ASSERT_FALSE(r()->Resolve());
-
-    std::string type_name = std::visit(
-        [&](auto&& value) {
-            using ValueType = std::decay_t<decltype(value)>;
-            return builder::FriendlyName<ValueType>();
-        },
-        c.lhs);
-
     EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '"));
-    EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + type_name + "'"));
+    EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + lhs->TypeName() + "'"));
 }
 INSTANTIATE_TEST_SUITE_P(
     Test,
@@ -854,10 +840,8 @@
 using ResolverConstEvalShiftLeftConcreteGeqBitWidthError =
     ResolverTestWithParam<std::tuple<Types, Types>>;
 TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) {
-    auto* lhs_expr =
-        std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam()));
-    auto* rhs_expr =
-        std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam()));
+    auto* lhs_expr = ToValueBase(std::get<0>(GetParam()))->Expr(*this);
+    auto* rhs_expr = ToValueBase(std::get<1>(GetParam()))->Expr(*this);
     GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(
@@ -880,10 +864,8 @@
 // AInt left shift results in sign change error
 using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<std::tuple<Types, Types>>;
 TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) {
-    auto* lhs_expr =
-        std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam()));
-    auto* rhs_expr =
-        std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam()));
+    auto* lhs_expr = ToValueBase(std::get<0>(GetParam()))->Expr(*this);
+    auto* rhs_expr = ToValueBase(std::get<1>(GetParam()))->Expr(*this);
     GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change");
diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc
index 3936fba..86223ba 100644
--- a/src/tint/resolver/const_eval_builtin_test.cc
+++ b/src/tint/resolver/const_eval_builtin_test.cc
@@ -83,54 +83,57 @@
         std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a);
     }
 
-    std::visit(
-        [&](auto&& expected) {
-            using T = typename std::decay_t<decltype(expected)>::ElementType;
-            auto* expr = Call(sem::str(builtin), std::move(args));
+    auto* expected = ToValueBase(c.expected);
+    auto* expr = Call(sem::str(builtin), std::move(args));
 
-            GlobalConst("C", expr);
-            auto* expected_expr = expected.Expr(*this);
-            GlobalConst("E", expected_expr);
+    GlobalConst("C", expr);
+    auto* expected_expr = expected->Expr(*this);
+    GlobalConst("E", expected_expr);
 
-            EXPECT_TRUE(r()->Resolve()) << r()->error();
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
 
-            auto* sem = Sem().Get(expr);
-            const sem::Constant* value = sem->ConstantValue();
-            ASSERT_NE(value, nullptr);
-            EXPECT_TYPE(value->Type(), sem->Type());
+    auto* sem = Sem().Get(expr);
+    const sem::Constant* value = sem->ConstantValue();
+    ASSERT_NE(value, nullptr);
+    EXPECT_TYPE(value->Type(), sem->Type());
 
-            auto* expected_sem = Sem().Get(expected_expr);
-            const sem::Constant* expected_value = expected_sem->ConstantValue();
-            ASSERT_NE(expected_value, nullptr);
-            EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
+    auto* expected_sem = Sem().Get(expected_expr);
+    const sem::Constant* expected_value = expected_sem->ConstantValue();
+    ASSERT_NE(expected_value, nullptr);
+    EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
 
-            ForEachElemPair(value, expected_value,
-                            [&](const sem::Constant* a, const sem::Constant* b) {
-                                auto v = a->As<T>();
-                                auto e = b->As<T>();
-                                if constexpr (std::is_same_v<bool, T>) {
-                                    EXPECT_EQ(v, e);
-                                } else if constexpr (IsFloatingPoint<T>) {
-                                    if (std::isnan(e)) {
-                                        EXPECT_TRUE(std::isnan(v));
-                                    } else {
-                                        auto vf = (c.expected_pos_or_neg ? Abs(v) : v);
-                                        if (c.float_compare) {
-                                            EXPECT_FLOAT_EQ(vf, e);
-                                        } else {
-                                            EXPECT_EQ(vf, e);
-                                        }
-                                    }
-                                } else {
-                                    EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e);
-                                    // Check that the constant's integer doesn't contain unexpected
-                                    // data in the MSBs that are outside of the bit-width of T.
-                                    EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
-                                }
-                                return HasFailure() ? Action::kStop : Action::kContinue;
-                            });
-        },
-        c.expected);
+    // @TODO(amaiorano): Rewrite using ScalarArgsFrom()
+    ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
+        std::visit(
+            [&](auto&& ct_expected) {
+                using T = typename std::decay_t<decltype(ct_expected)>::ElementType;
+
+                auto v = a->As<T>();
+                auto e = b->As<T>();
+                if constexpr (std::is_same_v<bool, T>) {
+                    EXPECT_EQ(v, e);
+                } else if constexpr (IsFloatingPoint<T>) {
+                    if (std::isnan(e)) {
+                        EXPECT_TRUE(std::isnan(v));
+                    } else {
+                        auto vf = (c.expected_pos_or_neg ? Abs(v) : v);
+                        if (c.float_compare) {
+                            EXPECT_FLOAT_EQ(vf, e);
+                        } else {
+                            EXPECT_EQ(vf, e);
+                        }
+                    }
+                } else {
+                    EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e);
+                    // Check that the constant's integer doesn't contain unexpected
+                    // data in the MSBs that are outside of the bit-width of T.
+                    EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
+                }
+            },
+            c.expected);
+
+        return HasFailure() ? Action::kStop : Action::kContinue;
+    });
 }
 
 INSTANTIATE_TEST_SUITE_P(  //
diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc
index 35657ee..7e6a6fc 100644
--- a/src/tint/resolver/const_eval_conversion_test.cc
+++ b/src/tint/resolver/const_eval_conversion_test.cc
@@ -29,20 +29,7 @@
     builder::Value<bool>>;
 
 static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) {
-    std::visit(
-        [&](auto&& v) {
-            using ValueType = std::decay_t<decltype(v)>;
-            o << ValueType::DataType::Name() << "(";
-            for (auto& a : v.args.values) {
-                o << std::get<typename ValueType::ElementType>(a);
-                if (&a != &v.args.values.Back()) {
-                    o << ", ";
-                }
-            }
-            o << ")";
-        },
-        scalar);
-    return o;
+    return ToValueBase(scalar)->Print(o);
 }
 
 enum class Kind {
@@ -96,7 +83,7 @@
     const auto& type = std::get<1>(GetParam()).type;
     const auto unrepresentable = std::get<1>(GetParam()).unrepresentable;
 
-    auto* input_val = std::visit([&](auto val) { return val.Expr(*this); }, input);
+    auto* input_val = ToValueBase(input)->Expr(*this);
     auto* expr = Construct(type.ast(*this), input_val);
     if (kind == Kind::kVector) {
         expr = Construct(ty.vec(nullptr, 3), expr);
@@ -120,7 +107,7 @@
         ASSERT_NE(sem->ConstantValue(), nullptr);
         EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
 
-        auto expected_values = std::visit([&](auto&& val) { return val.args; }, expected);
+        auto expected_values = ToValueBase(expected)->Args();
         if (kind == Kind::kVector) {
             expected_values.values.Push(expected_values.values[0]);
             expected_values.values.Push(expected_values.values[0]);
diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h
index 3840540..2761daf 100644
--- a/src/tint/resolver/const_eval_test.h
+++ b/src/tint/resolver/const_eval_test.h
@@ -41,6 +41,8 @@
 inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
     Switch(
         c->Type(),  //
+        [&](const sem::AbstractInt*) { args.values.Push(c->As<AInt>()); },
+        [&](const sem::AbstractFloat*) { args.values.Push(c->As<AFloat>()); },
         [&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
         [&](const sem::I32*) { args.values.Push(c->As<i32>()); },
         [&](const sem::U32*) { args.values.Push(c->As<u32>()); },
@@ -136,6 +138,7 @@
 using builder::Mat;
 using builder::Val;
 using builder::Value;
+using builder::ValueBase;
 using builder::Vec;
 
 using Types = std::variant<  //
@@ -188,21 +191,18 @@
     //
     >;
 
+/// Returns the current Value<T> in the `types` variant as a `ValueBase` pointer to use the
+/// polymorphic API. This trades longer compile times using std::variant for longer runtime via
+/// virtual function calls.
+template <typename ValueVariant>
+inline const ValueBase* ToValueBase(const ValueVariant& types) {
+    return std::visit(
+        [](auto&& t) -> const ValueBase* { return static_cast<const ValueBase*>(&t); }, types);
+}
+
+/// Prints Types to ostream
 inline std::ostream& operator<<(std::ostream& o, const Types& types) {
-    std::visit(
-        [&](auto&& v) {
-            using ValueType = std::decay_t<decltype(v)>;
-            o << ValueType::DataType::Name() << "(";
-            for (auto& a : v.args.values) {
-                o << std::get<typename ValueType::ElementType>(a);
-                if (&a != &v.args.values.Back()) {
-                    o << ", ";
-                }
-            }
-            o << ")";
-        },
-        types);
-    return o;
+    return ToValueBase(types)->Print(o);
 }
 
 // Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops
diff --git a/src/tint/resolver/const_eval_unary_op_test.cc b/src/tint/resolver/const_eval_unary_op_test.cc
index 80b8caa..fced490 100644
--- a/src/tint/resolver/const_eval_unary_op_test.cc
+++ b/src/tint/resolver/const_eval_unary_op_test.cc
@@ -51,40 +51,34 @@
 
     auto op = std::get<0>(GetParam());
     auto& c = std::get<1>(GetParam());
-    std::visit(
-        [&](auto&& expected) {
-            using T = typename std::decay_t<decltype(expected)>::ElementType;
 
-            auto* input_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.input);
-            auto* expr = create<ast::UnaryOpExpression>(op, input_expr);
+    auto* expected = ToValueBase(c.expected);
+    auto* input = ToValueBase(c.input);
 
-            GlobalConst("C", expr);
-            auto* expected_expr = expected.Expr(*this);
-            GlobalConst("E", expected_expr);
-            ASSERT_TRUE(r()->Resolve()) << r()->error();
+    auto* input_expr = input->Expr(*this);
+    auto* expr = create<ast::UnaryOpExpression>(op, input_expr);
 
-            auto* sem = Sem().Get(expr);
-            const sem::Constant* value = sem->ConstantValue();
-            ASSERT_NE(value, nullptr);
-            EXPECT_TYPE(value->Type(), sem->Type());
+    GlobalConst("C", expr);
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
 
-            auto* expected_sem = Sem().Get(expected_expr);
-            const sem::Constant* expected_value = expected_sem->ConstantValue();
-            ASSERT_NE(expected_value, nullptr);
-            EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
+    auto* sem = Sem().Get(expr);
+    const sem::Constant* value = sem->ConstantValue();
+    ASSERT_NE(value, nullptr);
+    EXPECT_TYPE(value->Type(), sem->Type());
 
-            ForEachElemPair(value, expected_value,
-                            [&](const sem::Constant* a, const sem::Constant* b) {
-                                EXPECT_EQ(a->As<T>(), b->As<T>());
-                                if constexpr (IsIntegral<T>) {
-                                    // Check that the constant's integer doesn't contain unexpected
-                                    // data in the MSBs that are outside of the bit-width of T.
-                                    EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
-                                }
-                                return HasFailure() ? Action::kStop : Action::kContinue;
-                            });
-        },
-        c.expected);
+    auto values_flat = ScalarArgsFrom(value);
+    auto expected_values_flat = expected->Args();
+    ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
+    for (size_t i = 0; i < values_flat.values.Length(); ++i) {
+        auto& a = values_flat.values[i];
+        auto& b = expected_values_flat.values[i];
+        EXPECT_EQ(a, b);
+        if (expected->IsIntegral()) {
+            // Check that the constant's integer doesn't contain unexpected
+            // data in the MSBs that are outside of the bit-width of T.
+            EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
+        }
+    }
 }
 INSTANTIATE_TEST_SUITE_P(Complement,
                          ResolverConstEvalUnaryOpTest,
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 923f3ff..57fe14a 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -206,6 +206,12 @@
     utils::Vector<Storage, 16> values;
 };
 
+/// Returns current variant value in `s` cast to type `T`
+template <typename T>
+T As(ScalarArgs::Storage& s) {
+    return std::visit([](auto&& v) { return static_cast<T>(v); }, s);
+}
+
 /// @param o the std::ostream to write to
 /// @param args the ScalarArgs
 /// @return the std::ostream so calls can be chained
@@ -750,10 +756,45 @@
             DataType<T>::Name};
 }
 
+/// Base class for Value<T>
+struct ValueBase {
+    /// Constructor
+    ValueBase() = default;
+    /// Destructor
+    virtual ~ValueBase() = default;
+    /// Move constructor
+    ValueBase(ValueBase&&) = default;
+    /// Copy constructor
+    ValueBase(const ValueBase&) = default;
+    /// Copy assignment operator
+    /// @returns this instance
+    ValueBase& operator=(const ValueBase&) = default;
+    /// Creates an `ast::Expression` for the type T passing in previously stored args
+    /// @param b the ProgramBuilder
+    /// @returns an expression node
+    virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0;
+    /// @returns args used to create expression via `Expr`
+    virtual const ScalarArgs& Args() const = 0;
+    /// @returns true if element type is abstract
+    virtual bool IsAbstract() const = 0;
+    /// @returns true if element type is an integral
+    virtual bool IsIntegral() const = 0;
+    /// @returns element type name
+    virtual std::string TypeName() const = 0;
+    /// Prints this value to the output stream
+    /// @param o the output stream
+    /// @returns input argument `o`
+    virtual std::ostream& Print(std::ostream& o) const = 0;
+};
+
 /// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
 /// expressions with.
 template <typename T>
-struct Value {
+struct Value : ValueBase {
+    /// Constructor
+    /// @param a the scalar args
+    explicit Value(ScalarArgs a) : args(std::move(a)) {}
+
     /// Alias to T
     using Type = T;
     /// Alias to DataType<T>
@@ -764,15 +805,43 @@
     /// Creates a Value<T> with `args`
     /// @param args the args that will be passed to the expression
     /// @returns a Value<T>
-    static Value Create(ScalarArgs args) { return Value{CreatePtrsFor<T>(), std::move(args)}; }
+    static Value Create(ScalarArgs args) { return Value{std::move(args)}; }
 
     /// Creates an `ast::Expression` for the type T passing in previously stored args
     /// @param b the ProgramBuilder
     /// @returns an expression node
-    const ast::Expression* Expr(ProgramBuilder& b) const { return (*create.expr)(b, args); }
+    const ast::Expression* Expr(ProgramBuilder& b) const override {
+        auto create = CreatePtrsFor<T>();
+        return (*create.expr)(b, args);
+    }
 
-    /// functions to create values / types of the value
-    CreatePtrs create;
+    /// @returns args used to create expression via `Expr`
+    const ScalarArgs& Args() const override { return args; }
+
+    /// @returns true if element type is abstract
+    bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
+
+    /// @returns true if element type is an integral
+    bool IsIntegral() const override { return tint::IsIntegral<ElementType>; }
+
+    /// @returns element type name
+    std::string TypeName() const override { return tint::FriendlyName<ElementType>(); }
+
+    /// Prints this value to the output stream
+    /// @param o the output stream
+    /// @returns input argument `o`
+    std::ostream& Print(std::ostream& o) const override {
+        o << TypeName() << "(";
+        for (auto& a : args.values) {
+            o << std::get<ElementType>(a);
+            if (&a != &args.values.Back()) {
+                o << ", ";
+            }
+        }
+        o << ")";
+        return o;
+    }
+
     /// args to create expression with
     ScalarArgs args;
 };