[tint] Move ZeroValue() to constant::Manager

This makes it accessible to other parts of Tint that need to create
zero-values for arbitrary types.

Change-Id: I845f5d01a92cf51fc3306e559760f75172226359
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/152462
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 4c58f59..74cc427 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -30,7 +30,6 @@
 #include "src/tint/lang/core/number.h"
 #include "src/tint/lang/core/type/abstract_float.h"
 #include "src/tint/lang/core/type/abstract_int.h"
-#include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/bool.h"
 #include "src/tint/lang/core/type/f16.h"
 #include "src/tint/lang/core/type/f32.h"
@@ -638,7 +637,7 @@
         if (!std::isfinite(v.value)) {
             AddError(OverflowErrorMessage(v, t->FriendlyName()), source);
             if (use_runtime_semantics_) {
-                return ZeroValue(t);
+                return mgr.Zero(t);
             } else {
                 return tint::Failure;
             }
@@ -647,52 +646,6 @@
     return mgr.Get<Scalar<T>>(t, v);
 }
 
-const Value* Eval::ZeroValue(const core::type::Type* type) {
-    return Switch(
-        type,  //
-        [&](const core::type::Vector* v) -> const Value* {
-            auto* zero_el = ZeroValue(v->type());
-            return mgr.Splat(type, zero_el, v->Width());
-        },
-        [&](const core::type::Matrix* m) -> const Value* {
-            auto* zero_el = ZeroValue(m->ColumnType());
-            return mgr.Splat(type, zero_el, m->columns());
-        },
-        [&](const core::type::Array* a) -> const Value* {
-            if (auto n = a->ConstantCount()) {
-                if (auto* zero_el = ZeroValue(a->ElemType())) {
-                    return mgr.Splat(type, zero_el, n.value());
-                }
-            }
-            return nullptr;
-        },
-        [&](const core::type::Struct* s) -> const Value* {
-            Hashmap<const core::type::Type*, const Value*, 8> zero_by_type;
-            Vector<const Value*, 4> zeros;
-            zeros.Reserve(s->Members().Length());
-            for (auto* member : s->Members()) {
-                auto* zero = zero_by_type.GetOrCreate(member->Type(),
-                                                      [&] { return ZeroValue(member->Type()); });
-                if (!zero) {
-                    return nullptr;
-                }
-                zeros.Push(zero);
-            }
-            if (zero_by_type.Count() == 1) {
-                // All members were of the same type, so the zero value is the same for all members.
-                return mgr.Splat(type, zeros[0], s->Members().Length());
-            }
-            return mgr.Composite(s, std::move(zeros));
-        },
-        [&](Default) -> const Value* {
-            return ZeroTypeDispatch(type, [&](auto zero) -> const Value* {
-                auto el = CreateScalar(Source{}, type, zero);
-                TINT_ASSERT(el);
-                return el.Get();
-            });
-        });
-}
-
 template <typename NumberT>
 tint::Result<NumberT> Eval::Add(const Source& source, NumberT a, NumberT b) {
     NumberT result;
@@ -1323,7 +1276,7 @@
 
 Eval::Result Eval::ArrayOrStructCtor(const core::type::Type* ty, VectorRef<const Value*> args) {
     if (args.IsEmpty()) {
-        return ZeroValue(ty);
+        return mgr.Zero(ty);
     }
 
     if (args.Length() == 1 && args[0]->Type() == ty) {
@@ -1351,7 +1304,7 @@
 }
 
 Eval::Result Eval::Zero(const core::type::Type* ty, VectorRef<const Value*>, const Source&) {
-    return ZeroValue(ty);
+    return mgr.Zero(ty);
 }
 
 Eval::Result Eval::Identity(const core::type::Type*, VectorRef<const Value*> args, const Source&) {
@@ -1436,7 +1389,7 @@
         }
         AddError("index " + std::to_string(idx) + " out of bounds" + range, idx_source);
         if (use_runtime_semantics_) {
-            return ZeroValue(el.type);
+            return mgr.Zero(el.type);
         } else {
             return tint::Failure;
         }
@@ -2178,7 +2131,7 @@
                 AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)",
                          source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2199,7 +2152,7 @@
             if (i < NumberT(1.0)) {
                 AddError("acosh must be called with a value >= 1.0", source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2234,7 +2187,7 @@
                 AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)",
                          source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2281,7 +2234,7 @@
                 AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)",
                          source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2551,7 +2504,7 @@
             if (!std::isfinite(val.value)) {
                 AddError(OverflowExpErrorMessage("e", e0), source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2573,7 +2526,7 @@
             if (!std::isfinite(val.value)) {
                 AddError(OverflowExpErrorMessage("2", e0), source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2922,7 +2875,7 @@
             if (e <= NumberT(0)) {
                 AddError("inverseSqrt must be called with a value > 0", source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -2979,7 +2932,7 @@
             if (e2 > bias + 1) {
                 AddError("e2 must be less than or equal to " + std::to_string(bias + 1), source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c1->Type());
+                    return mgr.Zero(c1->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -3015,7 +2968,7 @@
             if (v <= NumberT(0)) {
                 AddError("log must be called with a value > 0", source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -3036,7 +2989,7 @@
             if (v <= NumberT(0)) {
                 AddError("log2 must be called with a value > 0", source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -3161,7 +3114,7 @@
     if (v->AllZero()) {
         AddError("zero length vector can not be normalized", source);
         if (use_runtime_semantics_) {
-            return ZeroValue(ty);
+            return mgr.Zero(ty);
         } else {
             return tint::Failure;
         }
@@ -3282,7 +3235,7 @@
             if (!r) {
                 AddError(OverflowErrorMessage(e1, "^", e2), source);
                 if (use_runtime_semantics_) {
-                    return ZeroValue(c0->Type());
+                    return mgr.Zero(c0->Type());
                 } else {
                     return tint::Failure;
                 }
@@ -3436,7 +3389,7 @@
 
         // If k < 0.0, returns the refraction vector 0.0
         if (k.Get()->ValueAs<AFloat>() < 0) {
-            return ZeroValue(ty);
+            return mgr.Zero(ty);
         }
 
         // Otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1) + sqrt(k)) * e2
@@ -3862,7 +3815,7 @@
         if (!conv) {
             AddError(OverflowErrorMessage(value, "f16"), source);
             if (use_runtime_semantics_) {
-                return ZeroValue(c->Type());
+                return mgr.Zero(c->Type());
             } else {
                 return tint::Failure;
             }
diff --git a/src/tint/lang/core/constant/eval.h b/src/tint/lang/core/constant/eval.h
index 035f353..17f298e 100644
--- a/src/tint/lang/core/constant/eval.h
+++ b/src/tint/lang/core/constant/eval.h
@@ -953,9 +953,6 @@
     template <typename T>
     Eval::Result CreateScalar(const Source& source, const core::type::Type* t, T v);
 
-    /// ZeroValue returns a Constant for the zero-value of the type `type`.
-    const Value* ZeroValue(const core::type::Type* type);
-
     /// Adds two Number<T>s
     /// @param source the source location
     /// @param a the lhs number
diff --git a/src/tint/lang/core/constant/manager.cc b/src/tint/lang/core/constant/manager.cc
index 0a23ab4..22ee8ee 100644
--- a/src/tint/lang/core/constant/manager.cc
+++ b/src/tint/lang/core/constant/manager.cc
@@ -19,13 +19,17 @@
 #include "src/tint/lang/core/constant/splat.h"
 #include "src/tint/lang/core/type/abstract_float.h"
 #include "src/tint/lang/core/type/abstract_int.h"
+#include "src/tint/lang/core/type/array.h"
 #include "src/tint/lang/core/type/bool.h"
 #include "src/tint/lang/core/type/f16.h"
 #include "src/tint/lang/core/type/f32.h"
 #include "src/tint/lang/core/type/i32.h"
 #include "src/tint/lang/core/type/manager.h"
+#include "src/tint/lang/core/type/matrix.h"
 #include "src/tint/lang/core/type/u32.h"
+#include "src/tint/lang/core/type/vector.h"
 #include "src/tint/utils/containers/predicates.h"
+#include "src/tint/utils/rtti/switch.h"
 
 namespace tint::core::constant {
 
@@ -102,4 +106,54 @@
     return Get<Scalar<AInt>>(types.AInt(), value);
 }
 
+const Value* Manager::Zero(const core::type::Type* type) {
+    return Switch(
+        type,  //
+        [&](const core::type::Vector* v) -> const Value* {
+            auto* zero_el = Zero(v->type());
+            return Splat(type, zero_el, v->Width());
+        },
+        [&](const core::type::Matrix* m) -> const Value* {
+            auto* zero_el = Zero(m->ColumnType());
+            return Splat(type, zero_el, m->columns());
+        },
+        [&](const core::type::Array* a) -> const Value* {
+            if (auto n = a->ConstantCount()) {
+                if (auto* zero_el = Zero(a->ElemType())) {
+                    return Splat(type, zero_el, n.value());
+                }
+            }
+            return nullptr;
+        },
+        [&](const core::type::Struct* s) -> const Value* {
+            Hashmap<const core::type::Type*, const Value*, 8> zero_by_type;
+            Vector<const Value*, 4> zeros;
+            zeros.Reserve(s->Members().Length());
+            for (auto* member : s->Members()) {
+                auto* zero =
+                    zero_by_type.GetOrCreate(member->Type(), [&] { return Zero(member->Type()); });
+                if (!zero) {
+                    return nullptr;
+                }
+                zeros.Push(zero);
+            }
+            if (zero_by_type.Count() == 1) {
+                // All members were of the same type, so the zero value is the same for all members.
+                return Splat(type, zeros[0], s->Members().Length());
+            }
+            return Composite(s, std::move(zeros));
+        },
+        [&](Default) -> const Value* {
+            return Switch(
+                type,                                                              //
+                [&](const core::type::AbstractInt*) { return Get(AInt(0)); },      //
+                [&](const core::type::AbstractFloat*) { return Get(AFloat(0)); },  //
+                [&](const core::type::I32*) { return Get(i32(0)); },               //
+                [&](const core::type::U32*) { return Get(u32(0)); },               //
+                [&](const core::type::F32*) { return Get(f32(0)); },               //
+                [&](const core::type::F16*) { return Get(f16(0)); },               //
+                [&](const core::type::Bool*) { return Get(false); });
+        });
+}
+
 }  // namespace tint::core::constant
diff --git a/src/tint/lang/core/constant/manager.h b/src/tint/lang/core/constant/manager.h
index 3cc27be..dac8115 100644
--- a/src/tint/lang/core/constant/manager.h
+++ b/src/tint/lang/core/constant/manager.h
@@ -129,6 +129,11 @@
     /// @return a Scalar holding the AInt value @p value
     const Scalar<AInt>* Get(AInt value);
 
+    /// Constructs a constant zero-value of the type @p type.
+    /// @param type the constant type
+    /// @returns a constant zero-value for the type
+    const Value* Zero(const core::type::Type* type);
+
     /// The type manager
     core::type::Manager types;