Extract constant convert methods.

This CL pulls the convert methods out into standalone methods inside the
resolver and de-couples from the constants.

Bug: tint:1718
Change-Id: Id566704687b2d74e05eae860477552f88f6a06b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114120
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 94d60d4..d7de6cd 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -253,12 +253,6 @@
   public:
     ImplConstant() = default;
     ~ImplConstant() override = default;
-
-    /// Convert attempts to convert the constant value to the given type. On error, Convert()
-    /// creates a new diagnostic message and returns a Failure.
-    virtual utils::Result<const ImplConstant*> Convert(ProgramBuilder& builder,
-                                                       const type::Type* target_ty,
-                                                       const Source& source) const = 0;
 };
 
 /// A result templated with a ImplConstant.
@@ -297,62 +291,6 @@
     bool AllEqual() const override { return true; }
     size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
 
-    ImplResult Convert(ProgramBuilder& builder,
-                       const type::Type* target_ty,
-                       const Source& source) const override {
-        TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
-        if (target_ty == type) {
-            // If the types are identical, then no conversion is needed.
-            return this;
-        }
-        return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
-            // `value` is the source value.
-            // `FROM` is the source type.
-            // `TO` is the target type.
-            using TO = std::decay_t<decltype(zero_to)>;
-            using FROM = T;
-            if constexpr (std::is_same_v<TO, bool>) {
-                // [x -> bool]
-                return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(value));
-            } else if constexpr (std::is_same_v<FROM, bool>) {
-                // [bool -> x]
-                return builder.create<Scalar<TO>>(target_ty, TO(value ? 1 : 0));
-            } else if (auto conv = CheckedConvert<TO>(value)) {
-                // Conversion success
-                return builder.create<Scalar<TO>>(target_ty, conv.Get());
-                // --- Below this point are the failure cases ---
-            } else if constexpr (IsAbstract<FROM>) {
-                // [abstract-numeric -> x] - materialization failure
-                builder.Diagnostics().add_error(
-                    tint::diag::System::Resolver,
-                    OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source);
-                return utils::Failure;
-            } else if constexpr (IsFloatingPoint<TO>) {
-                // [x -> floating-point] - number not exactly representable
-                // https://www.w3.org/TR/WGSL/#floating-point-conversion
-                builder.Diagnostics().add_error(
-                    tint::diag::System::Resolver,
-                    OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source);
-                return utils::Failure;
-            } else if constexpr (IsFloatingPoint<FROM>) {
-                // [floating-point -> integer] - number not exactly representable
-                // https://www.w3.org/TR/WGSL/#floating-point-conversion
-                switch (conv.Failure()) {
-                    case ConversionFailure::kExceedsNegativeLimit:
-                        return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
-                    case ConversionFailure::kExceedsPositiveLimit:
-                        return builder.create<Scalar<TO>>(target_ty, TO::Highest());
-                }
-            } else if constexpr (IsIntegral<FROM>) {
-                // [integer -> integer] - number not exactly representable
-                // Static cast
-                return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(value));
-            }
-            return nullptr;  // Expression is not constant.
-        });
-        TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
-    }
-
     type::Type const* const type;
     const T value;
 };
@@ -373,23 +311,6 @@
     bool AllEqual() const override { return true; }
     size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
 
-    ImplResult Convert(ProgramBuilder& builder,
-                       const type::Type* target_ty,
-                       const Source& source) const override {
-        // Convert the single splatted element type.
-        // Note: This file is the only place where `constant::Constant`s are created, so this
-        // static_cast is safe.
-        auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
-            builder, type::Type::ElementOf(target_ty), source);
-        if (!conv_el) {
-            return utils::Failure;
-        }
-        if (!conv_el.Get()) {
-            return nullptr;
-        }
-        return builder.create<Splat>(target_ty, conv_el.Get(), count);
-    }
-
     type::Type const* const type;
     const constant::Constant* el;
     const size_t count;
@@ -418,41 +339,6 @@
     bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
     size_t Hash() const override { return hash; }
 
-    ImplResult Convert(ProgramBuilder& builder,
-                       const type::Type* target_ty,
-                       const Source& source) const override {
-        // Convert each of the composite element types.
-        utils::Vector<const constant::Constant*, 4> conv_els;
-        conv_els.Reserve(elements.Length());
-        std::function<const type::Type*(size_t idx)> target_el_ty;
-        if (auto* str = target_ty->As<type::Struct>()) {
-            if (str->Members().Length() != elements.Length()) {
-                TINT_ICE(Resolver, builder.Diagnostics())
-                    << "const-eval conversion of structure has mismatched element counts";
-                return utils::Failure;
-            }
-            target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
-        } else {
-            auto* el_ty = type::Type::ElementOf(target_ty);
-            target_el_ty = [el_ty](size_t) { return el_ty; };
-        }
-
-        for (auto* el : elements) {
-            // Note: This file is the only place where `constant::Constant`s are created, so the
-            // static_cast is safe.
-            auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
-                builder, target_el_ty(conv_els.Length()), source);
-            if (!conv_el) {
-                return utils::Failure;
-            }
-            if (!conv_el.Get()) {
-                return nullptr;
-            }
-            conv_els.Push(conv_el.Get());
-        }
-        return CreateComposite(builder, target_ty, std::move(conv_els));
-    }
-
     size_t CalcHash() {
         auto h = utils::Hash(type, all_zero, any_zero);
         for (auto* el : elements) {
@@ -468,6 +354,148 @@
     const size_t hash;
 };
 
+template <typename T>
+ImplResult ScalarConvert(const Scalar<T>* scalar,
+                         ProgramBuilder& builder,
+                         const type::Type* target_ty,
+                         const Source& source) {
+    TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
+    if (target_ty == scalar->type) {
+        // If the types are identical, then no conversion is needed.
+        return scalar;
+    }
+    return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
+        // `value` is the source value.
+        // `FROM` is the source type.
+        // `TO` is the target type.
+        using TO = std::decay_t<decltype(zero_to)>;
+        using FROM = T;
+        if constexpr (std::is_same_v<TO, bool>) {
+            // [x -> bool]
+            return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(scalar->value));
+        } else if constexpr (std::is_same_v<FROM, bool>) {
+            // [bool -> x]
+            return builder.create<Scalar<TO>>(target_ty, TO(scalar->value ? 1 : 0));
+        } else if (auto conv = CheckedConvert<TO>(scalar->value)) {
+            // Conversion success
+            return builder.create<Scalar<TO>>(target_ty, conv.Get());
+            // --- Below this point are the failure cases ---
+        } else if constexpr (IsAbstract<FROM>) {
+            // [abstract-numeric -> x] - materialization failure
+            builder.Diagnostics().add_error(
+                tint::diag::System::Resolver,
+                OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
+            return utils::Failure;
+        } else if constexpr (IsFloatingPoint<TO>) {
+            // [x -> floating-point] - number not exactly representable
+            // https://www.w3.org/TR/WGSL/#floating-point-conversion
+            builder.Diagnostics().add_error(
+                tint::diag::System::Resolver,
+                OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
+            return utils::Failure;
+        } else if constexpr (IsFloatingPoint<FROM>) {
+            // [floating-point -> integer] - number not exactly representable
+            // https://www.w3.org/TR/WGSL/#floating-point-conversion
+            switch (conv.Failure()) {
+                case ConversionFailure::kExceedsNegativeLimit:
+                    return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
+                case ConversionFailure::kExceedsPositiveLimit:
+                    return builder.create<Scalar<TO>>(target_ty, TO::Highest());
+            }
+        } else if constexpr (IsIntegral<FROM>) {
+            // [integer -> integer] - number not exactly representable
+            // Static cast
+            return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(scalar->value));
+        }
+        return nullptr;  // Expression is not constant.
+    });
+    TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+}
+
+// Forward declare
+ImplResult ConvertInternal(const constant::Constant* c,
+                           ProgramBuilder& builder,
+                           const type::Type* target_ty,
+                           const Source& source);
+
+ImplResult SplatConvert(const Splat* splat,
+                        ProgramBuilder& builder,
+                        const type::Type* target_ty,
+                        const Source& source) {
+    // Convert the single splatted element type.
+    auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source);
+    if (!conv_el) {
+        return utils::Failure;
+    }
+    if (!conv_el.Get()) {
+        return nullptr;
+    }
+    return builder.create<Splat>(target_ty, conv_el.Get(), splat->count);
+}
+
+ImplResult CompositeConvert(const Composite* composite,
+                            ProgramBuilder& builder,
+                            const type::Type* target_ty,
+                            const Source& source) {
+    // Convert each of the composite element types.
+    utils::Vector<const constant::Constant*, 4> conv_els;
+    conv_els.Reserve(composite->elements.Length());
+
+    std::function<const type::Type*(size_t idx)> target_el_ty;
+    if (auto* str = target_ty->As<type::Struct>()) {
+        if (str->Members().Length() != composite->elements.Length()) {
+            TINT_ICE(Resolver, builder.Diagnostics())
+                << "const-eval conversion of structure has mismatched element counts";
+            return utils::Failure;
+        }
+        target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
+    } else {
+        auto* el_ty = type::Type::ElementOf(target_ty);
+        target_el_ty = [el_ty](size_t) { return el_ty; };
+    }
+
+    for (auto* el : composite->elements) {
+        auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source);
+        if (!conv_el) {
+            return utils::Failure;
+        }
+        if (!conv_el.Get()) {
+            return nullptr;
+        }
+        conv_els.Push(conv_el.Get());
+    }
+    return CreateComposite(builder, target_ty, std::move(conv_els));
+}
+
+ImplResult ConvertInternal(const constant::Constant* c,
+                           ProgramBuilder& builder,
+                           const type::Type* target_ty,
+                           const Source& source) {
+    return Switch(
+        c,
+        [&](const Scalar<tint::AFloat>* val) {
+            return ScalarConvert(val, builder, target_ty, source);
+        },
+        [&](const Scalar<tint::AInt>* val) {
+            return ScalarConvert(val, builder, target_ty, source);
+        },
+        [&](const Scalar<tint::u32>* val) {
+            return ScalarConvert(val, builder, target_ty, source);
+        },
+        [&](const Scalar<tint::i32>* val) {
+            return ScalarConvert(val, builder, target_ty, source);
+        },
+        [&](const Scalar<tint::f32>* val) {
+            return ScalarConvert(val, builder, target_ty, source);
+        },
+        [&](const Scalar<tint::f16>* val) {
+            return ScalarConvert(val, builder, target_ty, source);
+        },
+        [&](const Scalar<bool>* val) { return ScalarConvert(val, builder, target_ty, source); },
+        [&](const Splat* val) { return SplatConvert(val, builder, target_ty, source); },
+        [&](const Composite* val) { return CompositeConvert(val, builder, target_ty, source); });
+}
+
 }  // namespace
 }  // namespace tint::resolver
 
@@ -3707,7 +3735,7 @@
     if (value->Type() == target_ty) {
         return value;
     }
-    return static_cast<const ImplConstant*>(value)->Convert(builder, target_ty, source);
+    return ConvertInternal(value, builder, target_ty, source);
 }
 
 void ConstEval::AddError(const std::string& msg, const Source& source) const {