tint: refactor ConstEval API to accept a vector of constants rather than expressions

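ConstEval evaluation functions now receive the already-evaluated argument
constants plus an explicit Source for diagnostics, instead of the argument
expressions. Error locations now come from that caller-supplied Source
rather than from combining the argument expressions' sources.

This is needed for a follow-up change to apply implicit conversions for
AFloat to AInt.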

Bug: chromium:1350147
Change-Id: Id903322d01b7aa420452c3e0fc1fa4e1c480c794
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/98683
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 7774a33..25766e6 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -1978,6 +1978,17 @@
                                              Expr(std::forward<RHS>(rhs)));
     }
 
+    /// @param source the source information
+    /// @param lhs the left hand argument to the addition operation
+    /// @param rhs the right hand argument to the addition operation
+    /// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs`
+    template <typename LHS, typename RHS>
+    const ast::BinaryExpression* Add(const Source& source, LHS&& lhs, RHS&& rhs) {
+        return create<ast::BinaryExpression>(source, ast::BinaryOp::kAdd,
+                                             Expr(std::forward<LHS>(lhs)),
+                                             Expr(std::forward<RHS>(rhs)));
+    }
+
     /// @param lhs the left hand argument to the and operation
     /// @param rhs the right hand argument to the and operation
     /// @returns a `ast::BinaryExpression` bitwise anding `lhs` and `rhs`
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 26a0403..5144e63 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -509,15 +509,6 @@
     return CreateComposite(builder, ty, std::move(els));
 }
 
-/// CombineSource returns the combined `Source`s of each expression in `exprs`.
-Source CombineSource(utils::VectorRef<const sem::Expression*> exprs) {
-    Source result = exprs[0]->Declaration()->source;
-    for (size_t i = 1; i < exprs.Length(); ++i) {
-        result = result.Combine(result, exprs[i]->Declaration()->source);
-    }
-    return result;
-}
-
 }  // namespace
 
 ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
@@ -575,20 +566,19 @@
 }
 
 ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty,
-                                          utils::VectorRef<const sem::Expression*> args) {
+                                          utils::VectorRef<const sem::Constant*> args,
+                                          const Source& source) {
     uint32_t el_count = 0;
     auto* el_ty = sem::Type::ElementOf(ty, &el_count);
     if (!el_ty) {
         return nullptr;
     }
 
-    auto& src = args[0]->Declaration()->source;
-    auto* arg = args[0]->ConstantValue();
-    if (!arg) {
+    if (!args[0]) {
         return nullptr;  // Single argument is not constant.
     }
 
-    if (auto conv = Convert(ty, arg, src)) {
+    if (auto conv = Convert(ty, args[0], source)) {
         return conv.Get();
     }
 
@@ -596,37 +586,38 @@
 }
 
 ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty,
-                                          utils::VectorRef<const sem::Expression*>) {
+                                          utils::VectorRef<const sem::Constant*>,
+                                          const Source&) {
     return ZeroValue(builder, ty);
 }
 
 ConstEval::ConstantResult ConstEval::Identity(const sem::Type*,
-                                              utils::VectorRef<const sem::Expression*> args) {
-    return args[0]->ConstantValue();
+                                              utils::VectorRef<const sem::Constant*> args,
+                                              const Source&) {
+    return args[0];
 }
 
 ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty,
-                                              utils::VectorRef<const sem::Expression*> args) {
-    if (auto* arg = args[0]->ConstantValue()) {
+                                              utils::VectorRef<const sem::Constant*> args,
+                                              const Source&) {
+    if (auto* arg = args[0]) {
         return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width());
     }
     return nullptr;
 }
 
 ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty,
-                                              utils::VectorRef<const sem::Expression*> args) {
-    utils::Vector<const sem::Constant*, 4> els;
-    for (auto* arg : args) {
-        els.Push(arg->ConstantValue());
-    }
-    return CreateComposite(builder, ty, std::move(els));
+                                              utils::VectorRef<const sem::Constant*> args,
+                                              const Source&) {
+    return CreateComposite(builder, ty, args);
 }
 
 ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty,
-                                              utils::VectorRef<const sem::Expression*> args) {
+                                              utils::VectorRef<const sem::Constant*> args,
+                                              const Source&) {
     utils::Vector<const sem::Constant*, 4> els;
     for (auto* arg : args) {
-        auto* val = arg->ConstantValue();
+        auto* val = arg;
         if (!val) {
             return nullptr;
         }
@@ -648,7 +639,8 @@
 }
 
 ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty,
-                                              utils::VectorRef<const sem::Expression*> args) {
+                                              utils::VectorRef<const sem::Constant*> args,
+                                              const Source&) {
     auto* m = static_cast<const sem::Matrix*>(ty);
 
     utils::Vector<const sem::Constant*, 4> els;
@@ -656,7 +648,7 @@
         utils::Vector<const sem::Constant*, 4> column;
         for (uint32_t r = 0; r < m->rows(); r++) {
             auto i = r + c * m->rows();
-            column.Push(args[i]->ConstantValue());
+            column.Push(args[i]);
         }
         els.Push(CreateComposite(builder, m->ColumnType(), std::move(column)));
     }
@@ -664,12 +656,9 @@
 }
 
 ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty,
-                                              utils::VectorRef<const sem::Expression*> args) {
-    utils::Vector<const sem::Constant*, 4> els;
-    for (auto* arg : args) {
-        els.Push(arg->ConstantValue());
-    }
-    return CreateComposite(builder, ty, std::move(els));
+                                              utils::VectorRef<const sem::Constant*> args,
+                                              const Source&) {
+    return CreateComposite(builder, ty, args);
 }
 
 ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr,
@@ -731,18 +720,20 @@
 }
 
 ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*,
-                                                  utils::VectorRef<const sem::Expression*> args) {
+                                                  utils::VectorRef<const sem::Constant*> args,
+                                                  const Source&) {
     auto transform = [&](const sem::Constant* c) {
         auto create = [&](auto i) {
             return CreateElement(builder, c->Type(), decltype(i)(~i.value));
         };
         return Dispatch_ia_iu32(create, c);
     };
-    return TransformElements(builder, transform, args[0]->ConstantValue());
+    return TransformElements(builder, transform, args[0]);
 }
 
 ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*,
-                                             utils::VectorRef<const sem::Expression*> args) {
+                                             utils::VectorRef<const sem::Constant*> args,
+                                             const Source&) {
     auto transform = [&](const sem::Constant* c) {
         auto create = [&](auto i) {
             // For signed integrals, avoid C++ UB by not negating the
@@ -762,11 +753,12 @@
         };
         return Dispatch_fia_fi32_f16(create, c);
     };
-    return TransformElements(builder, transform, args[0]->ConstantValue());
+    return TransformElements(builder, transform, args[0]);
 }
 
 ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
-                                            utils::VectorRef<const sem::Expression*> args) {
+                                            utils::VectorRef<const sem::Constant*> args,
+                                            const Source& source) {
     auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
         auto create = [&](auto i, auto j) -> const Constant* {
             using NumberT = decltype(i);
@@ -791,7 +783,7 @@
                     AddError("'" + std::to_string(add_values(i.value, j.value)) +
                                  "' cannot be represented as '" +
                                  ty->FriendlyName(builder.Symbols()) + "'",
-                             CombineSource(args));
+                             source);
                     return nullptr;
                 }
             } else {
@@ -802,8 +794,7 @@
         return Dispatch_fia_fiu32_f16(create, c0, c1);
     };
 
-    auto r = TransformBinaryElements(builder, transform, args[0]->ConstantValue(),
-                                     args[1]->ConstantValue());
+    auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
     if (builder.Diagnostics().contains_errors()) {
         return utils::Failure;
     }
@@ -811,19 +802,20 @@
 }
 
 ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
-                                           utils::VectorRef<const sem::Expression*> args) {
+                                           utils::VectorRef<const sem::Constant*> args,
+                                           const Source&) {
     auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
         auto create = [&](auto i, auto j) {
             return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value)));
         };
         return Dispatch_fa_f32_f16(create, c0, c1);
     };
-    return TransformElements(builder, transform, args[0]->ConstantValue(),
-                             args[1]->ConstantValue());
+    return TransformElements(builder, transform, args[0], args[1]);
 }
 
 ConstEval::ConstantResult ConstEval::clamp(const sem::Type*,
-                                           utils::VectorRef<const sem::Expression*> args) {
+                                           utils::VectorRef<const sem::Constant*> args,
+                                           const Source&) {
     auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
                          const sem::Constant* c2) {
         auto create = [&](auto e, auto low, auto high) {
@@ -832,8 +824,7 @@
         };
         return Dispatch_fia_fiu32_f16(create, c0, c1, c2);
     };
-    return TransformElements(builder, transform, args[0]->ConstantValue(), args[1]->ConstantValue(),
-                             args[2]->ConstantValue());
+    return TransformElements(builder, transform, args[0], args[1], args[2]);
 }
 
 utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty,
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 38bde53..dbc3dbd 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -57,7 +57,8 @@
 
     /// Typedef for a constant evaluation function
     using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty,
-                                                   utils::VectorRef<const sem::Expression*>);
+                                                   utils::VectorRef<const sem::Constant*>,
+                                                   const Source&);
 
     /// Constructor
     /// @param b the program builder
@@ -116,50 +117,74 @@
     /// Type conversion
     /// @param ty the result type
     /// @param args the input arguments
+    /// @param source the source location of the conversion
     /// @return the converted value, or null if the value cannot be calculated
-    ConstantResult Conv(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult Conv(const sem::Type* ty,
+                        utils::VectorRef<const sem::Constant*> args,
+                        const Source& source);
 
     /// Zero value type constructor
     /// @param ty the result type
     /// @param args the input arguments (no arguments provided)
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult Zero(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult Zero(const sem::Type* ty,
+                        utils::VectorRef<const sem::Constant*> args,
+                        const Source& source);
 
     /// Identity value type constructor
     /// @param ty the result type
     /// @param args the input arguments
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult Identity(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult Identity(const sem::Type* ty,
+                            utils::VectorRef<const sem::Constant*> args,
+                            const Source& source);
 
     /// Vector splat constructor
     /// @param ty the vector type
     /// @param args the input arguments
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult VecSplat(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult VecSplat(const sem::Type* ty,
+                            utils::VectorRef<const sem::Constant*> args,
+                            const Source& source);
 
     /// Vector constructor using scalars
     /// @param ty the vector type
     /// @param args the input arguments
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult VecCtorS(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult VecCtorS(const sem::Type* ty,
+                            utils::VectorRef<const sem::Constant*> args,
+                            const Source& source);
 
     /// Vector constructor using a mix of scalars and smaller vectors
     /// @param ty the vector type
     /// @param args the input arguments
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult VecCtorM(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult VecCtorM(const sem::Type* ty,
+                            utils::VectorRef<const sem::Constant*> args,
+                            const Source& source);
 
     /// Matrix constructor using scalar values
     /// @param ty the matrix type
     /// @param args the input arguments
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult MatCtorS(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult MatCtorS(const sem::Type* ty,
+                            utils::VectorRef<const sem::Constant*> args,
+                            const Source& source);
 
     /// Matrix constructor using column vectors
     /// @param ty the matrix type
     /// @param args the input arguments
+    /// @param source the source location of the constructor call
     /// @return the constructed value, or null if the value cannot be calculated
-    ConstantResult MatCtorV(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult MatCtorV(const sem::Type* ty,
+                            utils::VectorRef<const sem::Constant*> args,
+                            const Source& source);
 
     ////////////////////////////////////////////////////////////////////////////
     // Unary Operators
@@ -168,14 +193,20 @@
     /// Complement operator '~'
     /// @param ty the integer type
     /// @param args the input arguments
+    /// @param source the source location used for diagnostics
     /// @return the result value, or null if the value cannot be calculated
-    ConstantResult OpComplement(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult OpComplement(const sem::Type* ty,
+                                utils::VectorRef<const sem::Constant*> args,
+                                const Source& source);
 
     /// Minus operator '-'
     /// @param ty the expression type
     /// @param args the input arguments
+    /// @param source the source location used for diagnostics
     /// @return the result value, or null if the value cannot be calculated
-    ConstantResult OpMinus(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult OpMinus(const sem::Type* ty,
+                           utils::VectorRef<const sem::Constant*> args,
+                           const Source& source);
 
     ////////////////////////////////////////////////////////////////////////////
     // Binary Operators
@@ -184,8 +215,11 @@
     /// Plus operator '+'
     /// @param ty the expression type
     /// @param args the input arguments
+    /// @param source the source location of the expression, used for diagnostics
     /// @return the result value, or null if the value cannot be calculated
-    ConstantResult OpPlus(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult OpPlus(const sem::Type* ty,
+                          utils::VectorRef<const sem::Constant*> args,
+                          const Source& source);
 
     ////////////////////////////////////////////////////////////////////////////
     // Builtins
@@ -194,14 +228,20 @@
     /// atan2 builtin
     /// @param ty the expression type
     /// @param args the input arguments
+    /// @param source the source location of the builtin call
     /// @return the result value, or null if the value cannot be calculated
-    ConstantResult atan2(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult atan2(const sem::Type* ty,
+                         utils::VectorRef<const sem::Constant*> args,
+                         const Source& source);
 
     /// clamp builtin
     /// @param ty the expression type
     /// @param args the input arguments
+    /// @param source the source location of the builtin call
     /// @return the result value, or null if the value cannot be calculated
-    ConstantResult clamp(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
+    ConstantResult clamp(const sem::Type* ty,
+                         utils::VectorRef<const sem::Constant*> args,
+                         const Source& source);
 
   private:
     /// Adds the given error message to the diagnostics
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index e7c342d..9377fb8 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -3229,27 +3229,27 @@
                                               ))));
 
 TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
-    GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AInt::Highest()), 1_a));
+    GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
               "1:1 error: '-9223372036854775808' cannot be represented as 'abstract-int'");
 }
 
 TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AInt) {
-    GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AInt::Lowest()), -1_a));
+    GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Lowest()), -1_a));
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(),
               "1:1 error: '9223372036854775807' cannot be represented as 'abstract-int'");
 }
 
 TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AFloat) {
-    GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AFloat::Highest()), AFloat::Highest()));
+    GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AFloat::Highest()), AFloat::Highest()));
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(), "1:1 error: 'inf' cannot be represented as 'abstract-float'");
 }
 
 TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) {
-    GlobalConst("c", nullptr, Add(Expr(Source{{1, 1}}, AFloat::Lowest()), AFloat::Lowest()));
+    GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AFloat::Lowest()), AFloat::Lowest()));
     EXPECT_FALSE(r()->Resolve());
     EXPECT_EQ(r()->error(), "1:1 error: '-inf' cannot be represented as 'abstract-float'");
 }
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 046cf6f..5d52c14 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1585,8 +1585,10 @@
         const sem::Constant* value = nullptr;
         auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage);
         if (stage == sem::EvaluationStage::kConstant) {
+            auto const_args =
+                utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
             if (auto r = (const_eval_.*ctor_or_conv.const_eval_fn)(
-                    ctor_or_conv.target->ReturnType(), args)) {
+                    ctor_or_conv.target->ReturnType(), const_args, expr->source)) {
                 value = r.Get();
             } else {
                 return nullptr;
@@ -1891,7 +1893,9 @@
     // If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
     const sem::Constant* value = nullptr;
     if (stage == sem::EvaluationStage::kConstant) {
-        if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args)) {
+        auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); });
+        if (auto r = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), const_args,
+                                                          expr->source)) {
             value = r.Get();
         } else {
             return nullptr;
@@ -2297,7 +2301,8 @@
     auto stage = sem::EarliestStage(lhs->Stage(), rhs->Stage());
     if (stage == sem::EvaluationStage::kConstant) {
         if (op.const_eval_fn) {
-            if (auto r = (const_eval_.*op.const_eval_fn)(op.result, utils::Vector{lhs, rhs})) {
+            auto const_args = utils::Vector{lhs->ConstantValue(), rhs->ConstantValue()};
+            if (auto r = (const_eval_.*op.const_eval_fn)(op.result, const_args, expr->source)) {
                 value = r.Get();
             } else {
                 return nullptr;
@@ -2380,7 +2385,9 @@
             stage = expr->Stage();
             if (stage == sem::EvaluationStage::kConstant) {
                 if (op.const_eval_fn) {
-                    if (auto r = (const_eval_.*op.const_eval_fn)(ty, utils::Vector{expr})) {
+                    if (auto r = (const_eval_.*op.const_eval_fn)(
+                            ty, utils::Vector{expr->ConstantValue()},
+                            expr->Declaration()->source)) {
                         value = r.Get();
                     } else {
                         return nullptr;