Extract constant convert methods.
This CL pulls the convert methods out into standalone methods inside the
resolver and de-couples from the constants.
Bug: tint:1718
Change-Id: Id566704687b2d74e05eae860477552f88f6a06b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114120
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 94d60d4..d7de6cd 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -253,12 +253,6 @@
public:
ImplConstant() = default;
~ImplConstant() override = default;
-
- /// Convert attempts to convert the constant value to the given type. On error, Convert()
- /// creates a new diagnostic message and returns a Failure.
- virtual utils::Result<const ImplConstant*> Convert(ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source) const = 0;
};
/// A result templated with a ImplConstant.
@@ -297,62 +291,6 @@
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
- ImplResult Convert(ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source) const override {
- TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
- if (target_ty == type) {
- // If the types are identical, then no conversion is needed.
- return this;
- }
- return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
- // `value` is the source value.
- // `FROM` is the source type.
- // `TO` is the target type.
- using TO = std::decay_t<decltype(zero_to)>;
- using FROM = T;
- if constexpr (std::is_same_v<TO, bool>) {
- // [x -> bool]
- return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(value));
- } else if constexpr (std::is_same_v<FROM, bool>) {
- // [bool -> x]
- return builder.create<Scalar<TO>>(target_ty, TO(value ? 1 : 0));
- } else if (auto conv = CheckedConvert<TO>(value)) {
- // Conversion success
- return builder.create<Scalar<TO>>(target_ty, conv.Get());
- // --- Below this point are the failure cases ---
- } else if constexpr (IsAbstract<FROM>) {
- // [abstract-numeric -> x] - materialization failure
- builder.Diagnostics().add_error(
- tint::diag::System::Resolver,
- OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source);
- return utils::Failure;
- } else if constexpr (IsFloatingPoint<TO>) {
- // [x -> floating-point] - number not exactly representable
- // https://www.w3.org/TR/WGSL/#floating-point-conversion
- builder.Diagnostics().add_error(
- tint::diag::System::Resolver,
- OverflowErrorMessage(value, builder.FriendlyName(target_ty)), source);
- return utils::Failure;
- } else if constexpr (IsFloatingPoint<FROM>) {
- // [floating-point -> integer] - number not exactly representable
- // https://www.w3.org/TR/WGSL/#floating-point-conversion
- switch (conv.Failure()) {
- case ConversionFailure::kExceedsNegativeLimit:
- return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
- case ConversionFailure::kExceedsPositiveLimit:
- return builder.create<Scalar<TO>>(target_ty, TO::Highest());
- }
- } else if constexpr (IsIntegral<FROM>) {
- // [integer -> integer] - number not exactly representable
- // Static cast
- return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(value));
- }
- return nullptr; // Expression is not constant.
- });
- TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
- }
-
type::Type const* const type;
const T value;
};
@@ -373,23 +311,6 @@
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
- ImplResult Convert(ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source) const override {
- // Convert the single splatted element type.
- // Note: This file is the only place where `constant::Constant`s are created, so this
- // static_cast is safe.
- auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
- builder, type::Type::ElementOf(target_ty), source);
- if (!conv_el) {
- return utils::Failure;
- }
- if (!conv_el.Get()) {
- return nullptr;
- }
- return builder.create<Splat>(target_ty, conv_el.Get(), count);
- }
-
type::Type const* const type;
const constant::Constant* el;
const size_t count;
@@ -418,41 +339,6 @@
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; }
- ImplResult Convert(ProgramBuilder& builder,
- const type::Type* target_ty,
- const Source& source) const override {
- // Convert each of the composite element types.
- utils::Vector<const constant::Constant*, 4> conv_els;
- conv_els.Reserve(elements.Length());
- std::function<const type::Type*(size_t idx)> target_el_ty;
- if (auto* str = target_ty->As<type::Struct>()) {
- if (str->Members().Length() != elements.Length()) {
- TINT_ICE(Resolver, builder.Diagnostics())
- << "const-eval conversion of structure has mismatched element counts";
- return utils::Failure;
- }
- target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
- } else {
- auto* el_ty = type::Type::ElementOf(target_ty);
- target_el_ty = [el_ty](size_t) { return el_ty; };
- }
-
- for (auto* el : elements) {
- // Note: This file is the only place where `constant::Constant`s are created, so the
- // static_cast is safe.
- auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
- builder, target_el_ty(conv_els.Length()), source);
- if (!conv_el) {
- return utils::Failure;
- }
- if (!conv_el.Get()) {
- return nullptr;
- }
- conv_els.Push(conv_el.Get());
- }
- return CreateComposite(builder, target_ty, std::move(conv_els));
- }
-
size_t CalcHash() {
auto h = utils::Hash(type, all_zero, any_zero);
for (auto* el : elements) {
@@ -468,6 +354,148 @@
const size_t hash;
};
+template <typename T>
+ImplResult ScalarConvert(const Scalar<T>* scalar,
+ ProgramBuilder& builder,
+ const type::Type* target_ty,
+ const Source& source) {
+ TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
+ if (target_ty == scalar->type) {
+ // If the types are identical, then no conversion is needed.
+ return scalar;
+ }
+ return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
+ // `value` is the source value.
+ // `FROM` is the source type.
+ // `TO` is the target type.
+ using TO = std::decay_t<decltype(zero_to)>;
+ using FROM = T;
+ if constexpr (std::is_same_v<TO, bool>) {
+ // [x -> bool]
+ return builder.create<Scalar<TO>>(target_ty, !IsPositiveZero(scalar->value));
+ } else if constexpr (std::is_same_v<FROM, bool>) {
+ // [bool -> x]
+ return builder.create<Scalar<TO>>(target_ty, TO(scalar->value ? 1 : 0));
+ } else if (auto conv = CheckedConvert<TO>(scalar->value)) {
+ // Conversion success
+ return builder.create<Scalar<TO>>(target_ty, conv.Get());
+ // --- Below this point are the failure cases ---
+ } else if constexpr (IsAbstract<FROM>) {
+ // [abstract-numeric -> x] - materialization failure
+ builder.Diagnostics().add_error(
+ tint::diag::System::Resolver,
+ OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
+ return utils::Failure;
+ } else if constexpr (IsFloatingPoint<TO>) {
+ // [x -> floating-point] - number not exactly representable
+ // https://www.w3.org/TR/WGSL/#floating-point-conversion
+ builder.Diagnostics().add_error(
+ tint::diag::System::Resolver,
+ OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
+ return utils::Failure;
+ } else if constexpr (IsFloatingPoint<FROM>) {
+ // [floating-point -> integer] - number not exactly representable
+ // https://www.w3.org/TR/WGSL/#floating-point-conversion
+ switch (conv.Failure()) {
+ case ConversionFailure::kExceedsNegativeLimit:
+ return builder.create<Scalar<TO>>(target_ty, TO::Lowest());
+ case ConversionFailure::kExceedsPositiveLimit:
+ return builder.create<Scalar<TO>>(target_ty, TO::Highest());
+ }
+ } else if constexpr (IsIntegral<FROM>) {
+ // [integer -> integer] - number not exactly representable
+ // Static cast
+ return builder.create<Scalar<TO>>(target_ty, static_cast<TO>(scalar->value));
+ }
+ return nullptr; // Expression is not constant.
+ });
+ TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
+}
+
+// Forward declare
+ImplResult ConvertInternal(const constant::Constant* c,
+ ProgramBuilder& builder,
+ const type::Type* target_ty,
+ const Source& source);
+
+ImplResult SplatConvert(const Splat* splat,
+ ProgramBuilder& builder,
+ const type::Type* target_ty,
+ const Source& source) {
+ // Convert the single splatted element type.
+ auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source);
+ if (!conv_el) {
+ return utils::Failure;
+ }
+ if (!conv_el.Get()) {
+ return nullptr;
+ }
+ return builder.create<Splat>(target_ty, conv_el.Get(), splat->count);
+}
+
+ImplResult CompositeConvert(const Composite* composite,
+ ProgramBuilder& builder,
+ const type::Type* target_ty,
+ const Source& source) {
+ // Convert each of the composite element types.
+ utils::Vector<const constant::Constant*, 4> conv_els;
+ conv_els.Reserve(composite->elements.Length());
+
+ std::function<const type::Type*(size_t idx)> target_el_ty;
+ if (auto* str = target_ty->As<type::Struct>()) {
+ if (str->Members().Length() != composite->elements.Length()) {
+ TINT_ICE(Resolver, builder.Diagnostics())
+ << "const-eval conversion of structure has mismatched element counts";
+ return utils::Failure;
+ }
+ target_el_ty = [str](size_t idx) { return str->Members()[idx]->Type(); };
+ } else {
+ auto* el_ty = type::Type::ElementOf(target_ty);
+ target_el_ty = [el_ty](size_t) { return el_ty; };
+ }
+
+ for (auto* el : composite->elements) {
+ auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source);
+ if (!conv_el) {
+ return utils::Failure;
+ }
+ if (!conv_el.Get()) {
+ return nullptr;
+ }
+ conv_els.Push(conv_el.Get());
+ }
+ return CreateComposite(builder, target_ty, std::move(conv_els));
+}
+
+ImplResult ConvertInternal(const constant::Constant* c,
+ ProgramBuilder& builder,
+ const type::Type* target_ty,
+ const Source& source) {
+ return Switch(
+ c,
+ [&](const Scalar<tint::AFloat>* val) {
+ return ScalarConvert(val, builder, target_ty, source);
+ },
+ [&](const Scalar<tint::AInt>* val) {
+ return ScalarConvert(val, builder, target_ty, source);
+ },
+ [&](const Scalar<tint::u32>* val) {
+ return ScalarConvert(val, builder, target_ty, source);
+ },
+ [&](const Scalar<tint::i32>* val) {
+ return ScalarConvert(val, builder, target_ty, source);
+ },
+ [&](const Scalar<tint::f32>* val) {
+ return ScalarConvert(val, builder, target_ty, source);
+ },
+ [&](const Scalar<tint::f16>* val) {
+ return ScalarConvert(val, builder, target_ty, source);
+ },
+ [&](const Scalar<bool>* val) { return ScalarConvert(val, builder, target_ty, source); },
+ [&](const Splat* val) { return SplatConvert(val, builder, target_ty, source); },
+ [&](const Composite* val) { return CompositeConvert(val, builder, target_ty, source); });
+}
+
} // namespace
} // namespace tint::resolver
@@ -3707,7 +3735,7 @@
if (value->Type() == target_ty) {
return value;
}
- return static_cast<const ImplConstant*>(value)->Convert(builder, target_ty, source);
+ return ConvertInternal(value, builder, target_ty, source);
}
void ConstEval::AddError(const std::string& msg, const Source& source) const {