Move CreateComposite into ProgramBuilder.

This CL moves the CreateComposite helper into the ProgramBuilder.

Bug: tint:1718
Change-Id: I4aca7dc3d7192a7aa8b300f00529670aa9c09a27
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114202
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 83ad287..597463c 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -87,6 +87,8 @@
 #include "src/tint/ast/void.h"
 #include "src/tint/ast/while_statement.h"
 #include "src/tint/ast/workgroup_attribute.h"
+#include "src/tint/constant/composite.h"
+#include "src/tint/constant/splat.h"
 #include "src/tint/constant/value.h"
 #include "src/tint/number.h"
 #include "src/tint/program.h"
@@ -470,11 +472,73 @@
     /// @param args the arguments to pass to the constructor
     /// @returns the node pointer
     template <typename T, typename... ARGS>
-    traits::EnableIf<traits::IsTypeOrDerived<T, constant::Value>, T>* create(ARGS&&... args) {
+    traits::EnableIf<traits::IsTypeOrDerived<T, constant::Value> &&
+                         !traits::IsTypeOrDerived<T, constant::Composite> &&
+                         !traits::IsTypeOrDerived<T, constant::Splat>,
+                     T>*
+    create(ARGS&&... args) {
         AssertNotMoved();
         return constant_nodes_.Create<T>(std::forward<ARGS>(args)...);
     }
 
+    /// Constructs a constant of a vector, matrix or array type.
+    ///
+    /// Examines the element values and will return either a constant::Composite or a
+    /// constant::Splat, depending on the element types and values.
+    ///
+    /// @param type the composite type
+    /// @param elements the composite elements
+    /// @returns the node pointer
+    template <typename T>
+    traits::EnableIf<traits::IsTypeOrDerived<T, constant::Composite> ||
+                         traits::IsTypeOrDerived<T, constant::Splat>,
+                     const constant::Value>*
+    create(const type::Type* type, utils::VectorRef<const constant::Value*> elements) {
+        AssertNotMoved();
+        if (elements.IsEmpty()) {
+            return nullptr;
+        }
+
+        bool any_zero = false;
+        bool all_zero = true;
+        bool all_equal = true;
+        auto* first = elements.Front();
+        for (auto* el : elements) {
+            if (!el) {
+                return nullptr;
+            }
+            if (!any_zero && el->AnyZero()) {
+                any_zero = true;
+            }
+            if (all_zero && !el->AllZero()) {
+                all_zero = false;
+            }
+            if (all_equal && el != first) {
+                if (!el->Equal(first)) {
+                    all_equal = false;
+                }
+            }
+        }
+        if (all_equal) {
+            return create<constant::Splat>(type, elements[0], elements.Length());
+        }
+
+        return constant_nodes_.Create<constant::Composite>(type, std::move(elements), all_zero,
+                                                           any_zero);
+    }
+
+    /// Constructs a splat constant.
+    /// @param type the splat type
+    /// @param element the splat element
+    /// @param n the number of elements
+    /// @returns the node pointer
+    template <typename T>
+    traits::EnableIf<traits::IsTypeOrDerived<T, constant::Splat>, const constant::Splat>*
+    create(const type::Type* type, const constant::Value* element, size_t n) {
+        AssertNotMoved();
+        return constant_nodes_.Create<constant::Splat>(type, element, n);
+    }
+
     /// Creates a new type::Type owned by the ProgramBuilder.
     /// When the ProgramBuilder is destructed, owned ProgramBuilder and the
     /// returned `Type` will also be destructed.
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index a92b02a..628975b 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -232,11 +232,6 @@
     return count;
 }
 
-// Forward declaration
-const constant::Value* CreateComposite(ProgramBuilder& builder,
-                                       const type::Type* type,
-                                       utils::VectorRef<const constant::Value*> elements);
-
 template <typename T>
 ConstEval::Result ScalarConvert(const constant::Scalar<T>* scalar,
                                 ProgramBuilder& builder,
@@ -347,7 +342,7 @@
         }
         conv_els.Push(conv_el.Get());
     }
-    return CreateComposite(builder, target_ty, std::move(conv_els));
+    return builder.create<constant::Composite>(target_ty, std::move(conv_els));
 }
 
 ConstEval::Result ConvertInternal(const constant::Value* c,
@@ -438,7 +433,7 @@
                 // All members were of the same type, so the zero value is the same for all members.
                 return builder.create<constant::Splat>(type, zeros[0], s->Members().Length());
             }
-            return CreateComposite(builder, s, std::move(zeros));
+            return builder.create<constant::Composite>(s, std::move(zeros));
         },
         [&](Default) -> const constant::Value* {
             return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* {
@@ -449,42 +444,6 @@
         });
 }
 
-/// CreateComposite is used to construct a constant of a vector, matrix or array type.
-/// CreateComposite examines the element values and will return either a Composite or a Splat,
-/// depending on the element types and values.
-const constant::Value* CreateComposite(ProgramBuilder& builder,
-                                       const type::Type* type,
-                                       utils::VectorRef<const constant::Value*> elements) {
-    if (elements.IsEmpty()) {
-        return nullptr;
-    }
-    bool any_zero = false;
-    bool all_zero = true;
-    bool all_equal = true;
-    auto* first = elements.Front();
-    for (auto* el : elements) {
-        if (!el) {
-            return nullptr;
-        }
-        if (!any_zero && el->AnyZero()) {
-            any_zero = true;
-        }
-        if (all_zero && !el->AllZero()) {
-            all_zero = false;
-        }
-        if (all_equal && el != first) {
-            if (!el->Equal(first)) {
-                all_equal = false;
-            }
-        }
-    }
-    if (all_equal) {
-        return builder.create<constant::Splat>(type, elements[0], elements.Length());
-    } else {
-        return builder.create<constant::Composite>(type, std::move(elements), all_zero, any_zero);
-    }
-}
-
 namespace detail {
 /// Implementation of TransformElements
 template <typename F, typename... CONSTANTS>
@@ -515,7 +474,7 @@
             return el.Failure();
         }
     }
-    return CreateComposite(builder, composite_ty, std::move(els));
+    return builder.create<constant::Composite>(composite_ty, std::move(els));
 }
 }  // namespace detail
 
@@ -569,7 +528,7 @@
             return el.Failure();
         }
     }
-    return CreateComposite(builder, composite_ty, std::move(els));
+    return builder.create<constant::Composite>(composite_ty, std::move(els));
 }
 }  // namespace
 
@@ -1211,7 +1170,7 @@
     for (auto* arg : args) {
         els.Push(arg->ConstantValue());
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::Conv(const type::Type* ty,
@@ -1255,7 +1214,7 @@
 ConstEval::Result ConstEval::VecInitS(const type::Type* ty,
                                       utils::VectorRef<const constant::Value*> args,
                                       const Source&) {
-    return CreateComposite(builder, ty, args);
+    return builder.create<constant::Composite>(ty, args);
 }
 
 ConstEval::Result ConstEval::VecInitM(const type::Type* ty,
@@ -1281,7 +1240,7 @@
             els.Push(val);
         }
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::MatInitS(const type::Type* ty,
@@ -1296,15 +1255,15 @@
             auto i = r + c * m->rows();
             column.Push(args[i]);
         }
-        els.Push(CreateComposite(builder, m->ColumnType(), std::move(column)));
+        els.Push(builder.create<constant::Composite>(m->ColumnType(), std::move(column)));
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::MatInitV(const type::Type* ty,
                                       utils::VectorRef<const constant::Value*> args,
                                       const Source&) {
-    return CreateComposite(builder, ty, args);
+    return builder.create<constant::Composite>(ty, args);
 }
 
 ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr,
@@ -1357,7 +1316,7 @@
     }
     auto values = utils::Transform<4>(
         indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
-    return CreateComposite(builder, ty, std::move(values));
+    return builder.create<constant::Composite>(ty, std::move(values));
 }
 
 ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) {
@@ -1484,7 +1443,7 @@
         }
         result.Push(r.Get());
     }
-    return CreateComposite(builder, ty, result);
+    return builder.create<constant::Composite>(ty, result);
 }
 ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
                                               utils::VectorRef<const constant::Value*> args,
@@ -1534,7 +1493,7 @@
         }
         result.Push(r.Get());
     }
-    return CreateComposite(builder, ty, result);
+    return builder.create<constant::Composite>(ty, result);
 }
 
 ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
@@ -1596,9 +1555,9 @@
 
         // Add column vector to matrix
         auto* col_vec_ty = ty->As<type::Matrix>()->ColumnType();
-        result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec));
+        result_mat.Push(builder.create<constant::Composite>(col_vec_ty, col_vec));
     }
-    return CreateComposite(builder, ty, result_mat);
+    return builder.create<constant::Composite>(ty, result_mat);
 }
 
 ConstEval::Result ConstEval::OpDivide(const type::Type* ty,
@@ -2208,8 +2167,8 @@
         return utils::Failure;
     }
 
-    return CreateComposite(builder, ty,
-                           utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
+    return builder.create<constant::Composite>(
+        ty, utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
 }
 
 ConstEval::Result ConstEval::degrees(const type::Type* ty,
@@ -2592,21 +2551,20 @@
         }
         auto fract_ty = builder.create<type::Vector>(fract_els[0]->Type(), vec->Width());
         auto exp_ty = builder.create<type::Vector>(exp_els[0]->Type(), vec->Width());
-        return CreateComposite(builder, ty,
-                               utils::Vector<const constant::Value*, 2>{
-                                   CreateComposite(builder, fract_ty, std::move(fract_els)),
-                                   CreateComposite(builder, exp_ty, std::move(exp_els)),
-                               });
+        return builder.create<constant::Composite>(
+            ty, utils::Vector<const constant::Value*, 2>{
+                    builder.create<constant::Composite>(fract_ty, std::move(fract_els)),
+                    builder.create<constant::Composite>(exp_ty, std::move(exp_els)),
+                });
     } else {
         auto fe = scalar(arg);
         if (!fe.fract || !fe.exp) {
             return utils::Failure;
         }
-        return CreateComposite(builder, ty,
-                               utils::Vector<const constant::Value*, 2>{
-                                   fe.fract.Get(),
-                                   fe.exp.Get(),
-                               });
+        return builder.create<constant::Composite>(ty, utils::Vector<const constant::Value*, 2>{
+                                                           fe.fract.Get(),
+                                                           fe.exp.Get(),
+                                                       });
     }
 }
 
@@ -2838,7 +2796,7 @@
         return utils::Failure;
     }
 
-    return CreateComposite(builder, ty, std::move(fields));
+    return builder.create<constant::Composite>(ty, std::move(fields));
 }
 
 ConstEval::Result ConstEval::normalize(const type::Type* ty,
@@ -3412,9 +3370,10 @@
         for (size_t c = 0; c < mat_ty->columns(); ++c) {
             new_col_vec.Push(me(r, c));
         }
-        result_mat.Push(CreateComposite(builder, result_mat_ty->ColumnType(), new_col_vec));
+        result_mat.Push(
+            builder.create<constant::Composite>(result_mat_ty->ColumnType(), new_col_vec));
     }
-    return CreateComposite(builder, ty, result_mat);
+    return builder.create<constant::Composite>(ty, result_mat);
 }
 
 ConstEval::Result ConstEval::trunc(const type::Type* ty,
@@ -3450,7 +3409,7 @@
         }
         els.Push(el.Get());
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
@@ -3470,7 +3429,7 @@
         }
         els.Push(el.Get());
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
@@ -3489,7 +3448,7 @@
         }
         els.Push(el.Get());
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
@@ -3509,7 +3468,7 @@
         }
         els.Push(el.Get());
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
@@ -3528,7 +3487,7 @@
         }
         els.Push(el.Get());
     }
-    return CreateComposite(builder, ty, std::move(els));
+    return builder.create<constant::Composite>(ty, std::move(els));
 }
 
 ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,