[tint] Move ZeroValue() to constant::Manager
This makes it accessible to other parts of Tint that need to create
zero-values for arbitrary types.
Change-Id: I845f5d01a92cf51fc3306e559760f75172226359
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/152462
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/core/constant/eval.cc b/src/tint/lang/core/constant/eval.cc
index 4c58f59..74cc427 100644
--- a/src/tint/lang/core/constant/eval.cc
+++ b/src/tint/lang/core/constant/eval.cc
@@ -30,7 +30,6 @@
#include "src/tint/lang/core/number.h"
#include "src/tint/lang/core/type/abstract_float.h"
#include "src/tint/lang/core/type/abstract_int.h"
-#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/bool.h"
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
@@ -638,7 +637,7 @@
if (!std::isfinite(v.value)) {
AddError(OverflowErrorMessage(v, t->FriendlyName()), source);
if (use_runtime_semantics_) {
- return ZeroValue(t);
+ return mgr.Zero(t);
} else {
return tint::Failure;
}
@@ -647,52 +646,6 @@
return mgr.Get<Scalar<T>>(t, v);
}
-const Value* Eval::ZeroValue(const core::type::Type* type) {
- return Switch(
- type, //
- [&](const core::type::Vector* v) -> const Value* {
- auto* zero_el = ZeroValue(v->type());
- return mgr.Splat(type, zero_el, v->Width());
- },
- [&](const core::type::Matrix* m) -> const Value* {
- auto* zero_el = ZeroValue(m->ColumnType());
- return mgr.Splat(type, zero_el, m->columns());
- },
- [&](const core::type::Array* a) -> const Value* {
- if (auto n = a->ConstantCount()) {
- if (auto* zero_el = ZeroValue(a->ElemType())) {
- return mgr.Splat(type, zero_el, n.value());
- }
- }
- return nullptr;
- },
- [&](const core::type::Struct* s) -> const Value* {
- Hashmap<const core::type::Type*, const Value*, 8> zero_by_type;
- Vector<const Value*, 4> zeros;
- zeros.Reserve(s->Members().Length());
- for (auto* member : s->Members()) {
- auto* zero = zero_by_type.GetOrCreate(member->Type(),
- [&] { return ZeroValue(member->Type()); });
- if (!zero) {
- return nullptr;
- }
- zeros.Push(zero);
- }
- if (zero_by_type.Count() == 1) {
- // All members were of the same type, so the zero value is the same for all members.
- return mgr.Splat(type, zeros[0], s->Members().Length());
- }
- return mgr.Composite(s, std::move(zeros));
- },
- [&](Default) -> const Value* {
- return ZeroTypeDispatch(type, [&](auto zero) -> const Value* {
- auto el = CreateScalar(Source{}, type, zero);
- TINT_ASSERT(el);
- return el.Get();
- });
- });
-}
-
template <typename NumberT>
tint::Result<NumberT> Eval::Add(const Source& source, NumberT a, NumberT b) {
NumberT result;
@@ -1323,7 +1276,7 @@
Eval::Result Eval::ArrayOrStructCtor(const core::type::Type* ty, VectorRef<const Value*> args) {
if (args.IsEmpty()) {
- return ZeroValue(ty);
+ return mgr.Zero(ty);
}
if (args.Length() == 1 && args[0]->Type() == ty) {
@@ -1351,7 +1304,7 @@
}
Eval::Result Eval::Zero(const core::type::Type* ty, VectorRef<const Value*>, const Source&) {
- return ZeroValue(ty);
+ return mgr.Zero(ty);
}
Eval::Result Eval::Identity(const core::type::Type*, VectorRef<const Value*> args, const Source&) {
@@ -1436,7 +1389,7 @@
}
AddError("index " + std::to_string(idx) + " out of bounds" + range, idx_source);
if (use_runtime_semantics_) {
- return ZeroValue(el.type);
+ return mgr.Zero(el.type);
} else {
return tint::Failure;
}
@@ -2178,7 +2131,7 @@
AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)",
source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2199,7 +2152,7 @@
if (i < NumberT(1.0)) {
AddError("acosh must be called with a value >= 1.0", source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2234,7 +2187,7 @@
AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)",
source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2281,7 +2234,7 @@
AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)",
source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2551,7 +2504,7 @@
if (!std::isfinite(val.value)) {
AddError(OverflowExpErrorMessage("e", e0), source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2573,7 +2526,7 @@
if (!std::isfinite(val.value)) {
AddError(OverflowExpErrorMessage("2", e0), source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2922,7 +2875,7 @@
if (e <= NumberT(0)) {
AddError("inverseSqrt must be called with a value > 0", source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -2979,7 +2932,7 @@
if (e2 > bias + 1) {
AddError("e2 must be less than or equal to " + std::to_string(bias + 1), source);
if (use_runtime_semantics_) {
- return ZeroValue(c1->Type());
+ return mgr.Zero(c1->Type());
} else {
return tint::Failure;
}
@@ -3015,7 +2968,7 @@
if (v <= NumberT(0)) {
AddError("log must be called with a value > 0", source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -3036,7 +2989,7 @@
if (v <= NumberT(0)) {
AddError("log2 must be called with a value > 0", source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -3161,7 +3114,7 @@
if (v->AllZero()) {
AddError("zero length vector can not be normalized", source);
if (use_runtime_semantics_) {
- return ZeroValue(ty);
+ return mgr.Zero(ty);
} else {
return tint::Failure;
}
@@ -3282,7 +3235,7 @@
if (!r) {
AddError(OverflowErrorMessage(e1, "^", e2), source);
if (use_runtime_semantics_) {
- return ZeroValue(c0->Type());
+ return mgr.Zero(c0->Type());
} else {
return tint::Failure;
}
@@ -3436,7 +3389,7 @@
// If k < 0.0, returns the refraction vector 0.0
if (k.Get()->ValueAs<AFloat>() < 0) {
- return ZeroValue(ty);
+ return mgr.Zero(ty);
}
// Otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1) + sqrt(k)) * e2
@@ -3862,7 +3815,7 @@
if (!conv) {
AddError(OverflowErrorMessage(value, "f16"), source);
if (use_runtime_semantics_) {
- return ZeroValue(c->Type());
+ return mgr.Zero(c->Type());
} else {
return tint::Failure;
}
diff --git a/src/tint/lang/core/constant/eval.h b/src/tint/lang/core/constant/eval.h
index 035f353..17f298e 100644
--- a/src/tint/lang/core/constant/eval.h
+++ b/src/tint/lang/core/constant/eval.h
@@ -953,9 +953,6 @@
template <typename T>
Eval::Result CreateScalar(const Source& source, const core::type::Type* t, T v);
- /// ZeroValue returns a Constant for the zero-value of the type `type`.
- const Value* ZeroValue(const core::type::Type* type);
-
/// Adds two Number<T>s
/// @param source the source location
/// @param a the lhs number
diff --git a/src/tint/lang/core/constant/manager.cc b/src/tint/lang/core/constant/manager.cc
index 0a23ab4..22ee8ee 100644
--- a/src/tint/lang/core/constant/manager.cc
+++ b/src/tint/lang/core/constant/manager.cc
@@ -19,13 +19,17 @@
#include "src/tint/lang/core/constant/splat.h"
#include "src/tint/lang/core/type/abstract_float.h"
#include "src/tint/lang/core/type/abstract_int.h"
+#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/bool.h"
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
#include "src/tint/lang/core/type/manager.h"
+#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/u32.h"
+#include "src/tint/lang/core/type/vector.h"
#include "src/tint/utils/containers/predicates.h"
+#include "src/tint/utils/rtti/switch.h"
namespace tint::core::constant {
@@ -102,4 +106,54 @@
return Get<Scalar<AInt>>(types.AInt(), value);
}
+const Value* Manager::Zero(const core::type::Type* type) {
+ return Switch(
+ type, //
+ [&](const core::type::Vector* v) -> const Value* {
+ auto* zero_el = Zero(v->type());
+ return Splat(type, zero_el, v->Width());
+ },
+ [&](const core::type::Matrix* m) -> const Value* {
+ auto* zero_el = Zero(m->ColumnType());
+ return Splat(type, zero_el, m->columns());
+ },
+ [&](const core::type::Array* a) -> const Value* {
+ if (auto n = a->ConstantCount()) {
+ if (auto* zero_el = Zero(a->ElemType())) {
+ return Splat(type, zero_el, n.value());
+ }
+ }
+ return nullptr;
+ },
+ [&](const core::type::Struct* s) -> const Value* {
+ Hashmap<const core::type::Type*, const Value*, 8> zero_by_type;
+ Vector<const Value*, 4> zeros;
+ zeros.Reserve(s->Members().Length());
+ for (auto* member : s->Members()) {
+ auto* zero =
+ zero_by_type.GetOrCreate(member->Type(), [&] { return Zero(member->Type()); });
+ if (!zero) {
+ return nullptr;
+ }
+ zeros.Push(zero);
+ }
+ if (zero_by_type.Count() == 1) {
+ // All members were of the same type, so the zero value is the same for all members.
+ return Splat(type, zeros[0], s->Members().Length());
+ }
+ return Composite(s, std::move(zeros));
+ },
+ [&](Default) -> const Value* {
+ return Switch(
+ type, //
+ [&](const core::type::AbstractInt*) { return Get(AInt(0)); }, //
+ [&](const core::type::AbstractFloat*) { return Get(AFloat(0)); }, //
+ [&](const core::type::I32*) { return Get(i32(0)); }, //
+ [&](const core::type::U32*) { return Get(u32(0)); }, //
+ [&](const core::type::F32*) { return Get(f32(0)); }, //
+ [&](const core::type::F16*) { return Get(f16(0)); }, //
+ [&](const core::type::Bool*) { return Get(false); });
+ });
+}
+
} // namespace tint::core::constant
diff --git a/src/tint/lang/core/constant/manager.h b/src/tint/lang/core/constant/manager.h
index 3cc27be..dac8115 100644
--- a/src/tint/lang/core/constant/manager.h
+++ b/src/tint/lang/core/constant/manager.h
@@ -129,6 +129,11 @@
/// @return a Scalar holding the AInt value @p value
const Scalar<AInt>* Get(AInt value);
+ /// Constructs a constant zero-value of the type @p type.
+ /// @param type the constant type
+ /// @returns a constant zero-value for the type
+ const Value* Zero(const core::type::Type* type);
+
/// The type manager
core::type::Manager types;