| // Copyright 2022 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/tint/resolver/const_eval.h" |
| |
| #include <algorithm> |
| #include <iomanip> |
| #include <limits> |
| #include <optional> |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| |
| #include "src/tint/builtin/number.h" |
| #include "src/tint/constant/composite.h" |
| #include "src/tint/constant/scalar.h" |
| #include "src/tint/constant/splat.h" |
| #include "src/tint/constant/value.h" |
| #include "src/tint/program_builder.h" |
| #include "src/tint/sem/member_accessor_expression.h" |
| #include "src/tint/sem/value_constructor.h" |
| #include "src/tint/switch.h" |
| #include "src/tint/type/abstract_float.h" |
| #include "src/tint/type/abstract_int.h" |
| #include "src/tint/type/array.h" |
| #include "src/tint/type/bool.h" |
| #include "src/tint/type/f16.h" |
| #include "src/tint/type/f32.h" |
| #include "src/tint/type/i32.h" |
| #include "src/tint/type/matrix.h" |
| #include "src/tint/type/struct.h" |
| #include "src/tint/type/u32.h" |
| #include "src/tint/type/vector.h" |
| #include "src/tint/utils/bitcast.h" |
| #include "src/tint/utils/compiler_macros.h" |
| #include "src/tint/utils/map.h" |
| #include "src/tint/utils/string_stream.h" |
| #include "src/tint/utils/transform.h" |
| |
| using namespace tint::number_suffixes; // NOLINT |
| |
| namespace tint::resolver { |
| |
| namespace { |
| |
/// Returns the first element of a parameter pack.
/// The remaining arguments are accepted and ignored. They are taken as forwarding references rather
/// than through a C-style `...`, so any argument type may be passed: no default argument
/// promotions, no copies, and no conditionally-supported passing of non-trivial types.
/// @param first the value to return (perfectly forwarded)
/// @returns `first`, forwarded with its original value category
template <typename T, typename... REST>
T First(T&& first, REST&&...) {
    return std::forward<T>(first);
}
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_iu32(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_fiu32(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); }, |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_ia_iu32(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::AbstractInt*) { return f(cs->template ValueAs<AInt>()...); }, |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_ia_iu32_bool(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::AbstractInt*) { return f(cs->template ValueAs<AInt>()...); }, |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); }, |
| [&](const type::Bool*) { return f(cs->template ValueAs<bool>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_fia_fi32_f16(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::AbstractInt*) { return f(cs->template ValueAs<AInt>()...); }, |
| [&](const type::AbstractFloat*) { return f(cs->template ValueAs<AFloat>()...); }, |
| [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); }, |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::F16*) { return f(cs->template ValueAs<f16>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_fia_fiu32_f16(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::AbstractInt*) { return f(cs->template ValueAs<AInt>()...); }, |
| [&](const type::AbstractFloat*) { return f(cs->template ValueAs<AFloat>()...); }, |
| [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); }, |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); }, |
| [&](const type::F16*) { return f(cs->template ValueAs<f16>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_fia_fiu32_f16_bool(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::AbstractInt*) { return f(cs->template ValueAs<AInt>()...); }, |
| [&](const type::AbstractFloat*) { return f(cs->template ValueAs<AFloat>()...); }, |
| [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); }, |
| [&](const type::I32*) { return f(cs->template ValueAs<i32>()...); }, |
| [&](const type::U32*) { return f(cs->template ValueAs<u32>()...); }, |
| [&](const type::F16*) { return f(cs->template ValueAs<f16>()...); }, |
| [&](const type::Bool*) { return f(cs->template ValueAs<bool>()...); }); |
| } |
| |
| /// Helper that calls `f` passing in the value of all `cs`. |
| /// Calls `f` with all constants cast to the type of the first `cs` argument. |
| template <typename F, typename... CONSTANTS> |
| auto Dispatch_fa_f32_f16(F&& f, CONSTANTS&&... cs) { |
| return Switch( |
| First(cs...)->Type(), // |
| [&](const type::AbstractFloat*) { return f(cs->template ValueAs<AFloat>()...); }, |
| [&](const type::F32*) { return f(cs->template ValueAs<f32>()...); }, |
| [&](const type::F16*) { return f(cs->template ValueAs<f16>()...); }); |
| } |
| |
/// Helper that calls `f` passing in the value of all `cs`.
/// Unlike the other Dispatch_* helpers, no type switch is performed. Every constant is read with
/// `ValueAs<bool>()`, whatever the type of the first `cs` argument.
/// @returns the value returned by `f`.
template <typename F, typename... CONSTANTS>
auto Dispatch_bool(F&& f, CONSTANTS&&... cs) {
    return f(cs->template ValueAs<bool>()...);
}
| |
| /// ZeroTypeDispatch is a helper for calling the function `f`, passing a single zero-value argument |
| /// of the C++ type that corresponds to the type::Type `type`. For example, calling |
| /// `ZeroTypeDispatch()` with a type of `type::I32*` will call the function f with a single argument |
| /// of `i32(0)`. |
| /// @returns the value returned by calling `f`. |
| /// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will |
| /// return the zero-initialized value of the return type for `f`. |
| template <typename F> |
| auto ZeroTypeDispatch(const type::Type* type, F&& f) { |
| return Switch( |
| type, // |
| [&](const type::AbstractInt*) { return f(AInt(0)); }, // |
| [&](const type::AbstractFloat*) { return f(AFloat(0)); }, // |
| [&](const type::I32*) { return f(i32(0)); }, // |
| [&](const type::U32*) { return f(u32(0)); }, // |
| [&](const type::F32*) { return f(f32(0)); }, // |
| [&](const type::F16*) { return f(f16(0)); }, // |
| [&](const type::Bool*) { return f(static_cast<bool>(0)); }); |
| } |
| |
| template <typename NumberT> |
| std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) { |
| utils::StringStream ss; |
| ss << "'" << lhs.value << " " << op << " " << rhs.value << "' cannot be represented as '" |
| << FriendlyName<NumberT>() << "'"; |
| return ss.str(); |
| } |
| |
| template <typename VALUE_TY> |
| std::string OverflowErrorMessage(VALUE_TY value, std::string_view target_ty) { |
| utils::StringStream ss; |
| ss << "value " << value << " cannot be represented as " |
| << "'" << target_ty << "'"; |
| return ss.str(); |
| } |
| |
| template <typename NumberT> |
| std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) { |
| utils::StringStream ss; |
| ss << base << "^" << exp << " cannot be represented as " |
| << "'" << FriendlyName<NumberT>() << "'"; |
| return ss.str(); |
| } |
| |
/// Counts, starting from the most significant bit of `@p e`, how many consecutive bits equal the
/// requested bit value. A zero `@p bit_value_to_count` counts clear bits; any other value counts
/// set bits.
/// @returns the count, in the range [0, number of bits in T].
template <typename T>
std::make_unsigned_t<T> CountLeadingBits(T e, T bit_value_to_count) {
    using UT = std::make_unsigned_t<T>;
    constexpr UT kNumBits = sizeof(UT) * 8;
    constexpr UT kTopBit = UT{1} << (kNumBits - 1);
    const UT wanted = (bit_value_to_count == 0) ? UT{0} : kTopBit;

    UT bits = static_cast<UT>(e);
    UT count = 0;
    for (; count < kNumBits; ++count, bits <<= 1) {
        if ((bits & kTopBit) != wanted) {
            break;
        }
    }
    return count;
}
| |
/// @returns the number of consecutive trailing (least significant) bits of `@p e` that equal
/// `@p bit_value_to_count`. A zero `@p bit_value_to_count` counts clear bits; any other value counts
/// set bits, matching CountLeadingBits.
template <typename T>
std::make_unsigned_t<T> CountTrailingBits(T e, T bit_value_to_count) {
    using UT = std::make_unsigned_t<T>;
    constexpr UT kNumBits = sizeof(UT) * 8;
    constexpr UT kRightMost = UT{1};
    // Normalise to a single-bit mask. A plain cast would make any non-0/1 value (e.g. -1, all bits
    // set) compare unequal to `v & kRightMost`, silently returning 0.
    const UT b = bit_value_to_count == 0 ? UT{0} : kRightMost;

    auto v = static_cast<UT>(e);
    auto count = UT{0};
    while ((count < kNumBits) && ((v & kRightMost) == b)) {
        ++count;
        v >>= 1;
    }
    return count;
}
| |
/// Common data for constant conversion.
/// Passed by reference to ScalarConvert and ConvertInternal. Members are listed in the order used
/// for aggregate initialization, so do not reorder them.
struct ConvertContext {
    /// Builder used to create the converted constants and to report diagnostics.
    ProgramBuilder& builder;
    /// Source location attached to any diagnostic raised during the conversion.
    const Source& source;
    /// When true, ScalarConvert reports some out-of-range conversions as warnings, producing the
    /// target type's lowest/highest value, instead of errors that fail the conversion.
    bool use_runtime_semantics;
};
| |
/// Converts the constant scalar value to the target type.
///
/// If `target_ty` is already the scalar's type, `scalar` is returned unchanged. Otherwise
/// ZeroTypeDispatch selects the C++ target type `TO` from `target_ty`, and the value is converted
/// with CheckedConvert. When that conversion fails, the outcome depends on the kinds involved:
///  * abstract source: an error and nullptr; under runtime semantics, a warning and the target's
///    lowest/highest value instead;
///  * floating-point target: handled the same way as an abstract source;
///  * floating-point source, integer target: the target's lowest/highest value, no diagnostic;
///  * integer source, integer target: a `static_cast` of the value.
/// If `target_ty` is not a scalar type, ZeroTypeDispatch does not invoke the lambda, so the
/// zero-initialized result (nullptr) is returned.
/// @param scalar the scalar constant to convert
/// @param target_ty the type to convert to
/// @param ctx the conversion context (builder, diagnostic source, runtime-semantics flag)
/// @returns the converted scalar, or nullptr on error.
template <typename T>
const constant::ScalarBase* ScalarConvert(const constant::Scalar<T>* scalar,
                                          const type::Type* target_ty,
                                          ConvertContext& ctx) {
    TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
    if (target_ty == scalar->type) {
        // If the types are identical, then no conversion is needed.
        return scalar;
    }
    // Only the type of `zero_to` is used: it is the C++ type corresponding to `target_ty`.
    return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> const constant::ScalarBase* {
        // `scalar->value` is the source value.
        // `FROM` is the source type.
        // `TO` is the target type.
        using TO = std::decay_t<decltype(zero_to)>;
        using FROM = T;
        if constexpr (std::is_same_v<TO, bool>) {
            // [x -> bool]
            // NOTE(review): if IsPositiveZero() is false for -0.0, then bool(-0.0) evaluates to
            // true here, whereas WGSL specifies false for both zeros — confirm.
            return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                   !scalar->IsPositiveZero());
        } else if constexpr (std::is_same_v<FROM, bool>) {
            // [bool -> x]
            return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                   TO(scalar->value ? 1 : 0));
        } else if (auto conv = CheckedConvert<TO>(scalar->value)) {
            // Conversion success
            return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty, conv.Get());
            // --- Below this point are the failure cases ---
            // `conv` is still in scope in the branches below and holds the ConversionFailure.
        } else if constexpr (IsAbstract<FROM>) {
            // [abstract-numeric -> x] - materialization failure
            auto msg = OverflowErrorMessage(scalar->value, target_ty->FriendlyName());
            if (ctx.use_runtime_semantics) {
                // Runtime semantics: warn, then saturate to the target's limit.
                ctx.builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg,
                                                      ctx.source);
                switch (conv.Failure()) {
                    case ConversionFailure::kExceedsNegativeLimit:
                        return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                               TO::Lowest());
                    case ConversionFailure::kExceedsPositiveLimit:
                        return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                               TO::Highest());
                }
            } else {
                ctx.builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, ctx.source);
                return nullptr;
            }
        } else if constexpr (IsFloatingPoint<TO>) {
            // [x -> floating-point] - number not exactly representable
            // https://www.w3.org/TR/WGSL/#floating-point-conversion
            auto msg = OverflowErrorMessage(scalar->value, target_ty->FriendlyName());
            if (ctx.use_runtime_semantics) {
                // Runtime semantics: warn, then saturate to the target's limit.
                ctx.builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg,
                                                      ctx.source);
                switch (conv.Failure()) {
                    case ConversionFailure::kExceedsNegativeLimit:
                        return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                               TO::Lowest());
                    case ConversionFailure::kExceedsPositiveLimit:
                        return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                               TO::Highest());
                }
            } else {
                ctx.builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, ctx.source);
                return nullptr;
            }
        } else if constexpr (IsFloatingPoint<FROM>) {
            // [floating-point -> integer] - number not exactly representable
            // https://www.w3.org/TR/WGSL/#floating-point-conversion
            // Saturates to the target's limit without emitting a diagnostic.
            switch (conv.Failure()) {
                case ConversionFailure::kExceedsNegativeLimit:
                    return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty, TO::Lowest());
                case ConversionFailure::kExceedsPositiveLimit:
                    return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                           TO::Highest());
            }
        } else if constexpr (IsIntegral<FROM>) {
            // [integer -> integer] - number not exactly representable
            // Static cast
            return ctx.builder.constants.Get<constant::Scalar<TO>>(target_ty,
                                                                   static_cast<TO>(scalar->value));
        }
        // Reached only if a `switch (conv.Failure())` above matched neither case.
        TINT_UNREACHABLE(Resolver, ctx.builder.Diagnostics()) << "Expression is not constant";
        return nullptr;
    });
    TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
| |
/// Converts the constant value to the target type.
///
/// The conversion runs iteratively, without recursion, using two stacks:
///  * `pending` holds the actions still to run, popped last-in first-out. An ActionConvert converts
///    one value. An ActionBuildSplat or ActionBuildComposite is pushed beneath the conversions of
///    its elements, so it runs after those conversions have produced their results.
///  * `value_stack` holds converted values waiting to be assembled into their parent.
/// Scalars are converted with ScalarConvert. For a structure target, each element is converted to
/// the matching member type. For other composites, every element is converted to the target's
/// element type.
/// @param root_value the value to convert
/// @param root_target_ty the type to convert `root_value` to
/// @param ctx the conversion context
/// @returns the converted value, or nullptr on error.
const constant::Value* ConvertInternal(const constant::Value* root_value,
                                       const type::Type* root_target_ty,
                                       ConvertContext& ctx) {
    /// Convert `value` to `target_ty`, pushing the result onto `value_stack`.
    struct ActionConvert {
        const constant::Value* value = nullptr;
        const type::Type* target_ty = nullptr;
    };
    /// Pop one converted element and push a splat of `count` copies of it with type `type`.
    struct ActionBuildSplat {
        size_t count = 0;
        const type::Type* type = nullptr;
    };
    /// Pop `count` converted elements and push a composite of type `type` built from them.
    struct ActionBuildComposite {
        size_t count = 0;
        const type::Type* type = nullptr;
    };
    using Action = std::variant<ActionConvert, ActionBuildSplat, ActionBuildComposite>;

    utils::Vector<Action, 8> pending{
        ActionConvert{root_value, root_target_ty},
    };

    utils::Vector<const constant::Value*, 32> value_stack;

    while (!pending.IsEmpty()) {
        auto next = pending.Pop();

        if (auto* build = std::get_if<ActionBuildSplat>(&next)) {
            TINT_ASSERT(Resolver, value_stack.Length() >= 1);
            auto* el = value_stack.Pop();
            value_stack.Push(ctx.builder.constants.Splat(build->type, el, build->count));
            continue;
        }

        if (auto* build = std::get_if<ActionBuildComposite>(&next)) {
            TINT_ASSERT(Resolver, value_stack.Length() >= build->count);
            // Take build->count elements off the top of value_stack
            // Note: The values are ordered with the first composite value at the top of the stack.
            utils::Vector<const constant::Value*, 32> elements;
            elements.Reserve(build->count);
            for (size_t i = 0; i < build->count; i++) {
                elements.Push(value_stack.Pop());
            }
            // Build the composite
            value_stack.Push(ctx.builder.constants.Composite(build->type, std::move(elements)));
            continue;
        }

        // The only remaining alternative of the variant.
        auto* convert = std::get_if<ActionConvert>(&next);

        bool ok = Switch(
            convert->value,
            [&](const constant::ScalarBase* scalar) {
                // Dispatch on the scalar's concrete element type, then convert it directly.
                auto* converted = Switch(
                    scalar,
                    [&](const constant::Scalar<tint::AFloat>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    },
                    [&](const constant::Scalar<tint::AInt>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    },
                    [&](const constant::Scalar<tint::u32>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    },
                    [&](const constant::Scalar<tint::i32>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    },
                    [&](const constant::Scalar<tint::f32>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    },
                    [&](const constant::Scalar<tint::f16>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    },
                    [&](const constant::Scalar<bool>* val) {
                        return ScalarConvert(val, convert->target_ty, ctx);
                    });
                if (!converted) {
                    return false;
                }
                value_stack.Push(converted);
                return true;
            },
            [&](const constant::Splat* splat) {
                const type::Type* target_el_ty = nullptr;
                if (auto* str = convert->target_ty->As<type::Struct>()) {
                    // Structure conversion.
                    auto members = str->Members();
                    target_el_ty = members[0]->Type();

                    // Structures can only be converted during materialization. The user cannot
                    // declare the target structure type, so each member type must be the same
                    // default materialization type.
                    for (size_t i = 1; i < members.Length(); i++) {
                        if (members[i]->Type() != target_el_ty) {
                            TINT_ICE(Resolver, ctx.builder.Diagnostics())
                                << "inconsistent target struct member types for SplatConvert";
                            return false;
                        }
                    }
                } else {
                    target_el_ty = convert->target_ty->Elements(convert->target_ty).type;
                }

                // Convert the single splatted element type.
                // The build action is pushed first so it is popped after the element conversion.
                pending.Push(ActionBuildSplat{splat->count, convert->target_ty});
                pending.Push(ActionConvert{splat->el, target_el_ty});
                return true;
            },
            [&](const constant::Composite* composite) {
                const size_t el_count = composite->NumElements();

                // Build the new composite from the converted element types.
                // Pushed before the element conversions so it runs after all of them.
                pending.Push(ActionBuildComposite{el_count, convert->target_ty});

                if (auto* str = convert->target_ty->As<type::Struct>()) {
                    if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
                        TINT_ICE(Resolver, ctx.builder.Diagnostics())
                            << "const-eval conversion of structure has mismatched element counts";
                        return false;
                    }
                    // Struct composites can have different types for each member.
                    auto members = str->Members();
                    for (size_t i = 0; i < el_count; i++) {
                        pending.Push(ActionConvert{composite->Index(i), members[i]->Type()});
                    }
                } else {
                    // Non-struct composites have the same type for all elements.
                    auto* el_ty = convert->target_ty->Elements(convert->target_ty).type;
                    for (size_t i = 0; i < el_count; i++) {
                        auto* el = composite->Index(i);
                        pending.Push(ActionConvert{el, el_ty});
                    }
                }

                return true;
            });
        if (!ok) {
            return nullptr;
        }
    }

    // Every build action consumed its elements, leaving only the converted root value.
    TINT_ASSERT(Resolver, value_stack.Length() == 1);
    return value_stack.Pop();
}
| |
namespace detail {
/// Implementation of TransformElements.
///
/// If the type of the first constant in `cs` reports no element type (via `Type()->Elements()`),
/// `f` is called directly on `cs`. `index` is appended as an extra argument when the last parameter
/// of `f` is a `size_t`. Otherwise this recurses into each of the `n` elements of every constant in
/// `cs`; all of `cs` are assumed to have the first constant's element count. Element `i` receives
/// `index + i` as its index, and the results are assembled into a composite of type `composite_ty`.
/// NOTE(review): for nested aggregates the index handed to `f` is the sum of the element indices
/// along the path (e.g. column + row for a matrix), not a flattened position — confirm this is what
/// callers expect.
/// @returns the transformed constant, or the failure of the first element whose transform failed.
template <typename F, typename... CONSTANTS>
ConstEval::Result TransformElements(ProgramBuilder& builder,
                                    const type::Type* composite_ty,
                                    F&& f,
                                    size_t index,
                                    CONSTANTS&&... cs) {
    auto [el_ty, n] = First(cs...)->Type()->Elements();
    if (!el_ty) {
        // Leaf (non-composite) value: apply `f`, optionally with the accumulated index.
        constexpr bool kHasIndexParam =
            utils::traits::IsType<size_t, utils::traits::LastParameterType<F>>;
        if constexpr (kHasIndexParam) {
            return f(cs..., index);
        } else {
            return f(cs...);
        }
    }

    auto* composite_el_ty = composite_ty->Elements(composite_ty).type;

    utils::Vector<const constant::Value*, 8> els;
    els.Reserve(n);
    for (uint32_t i = 0; i < n; i++) {
        // `f` is re-forwarded on each iteration. The recursive parameter is a forwarding reference,
        // so this only binds a reference and does not move from `f`.
        if (auto el = detail::TransformElements(builder, composite_el_ty, std::forward<F>(f),
                                                index + i, cs->Index(i)...)) {
            els.Push(el.Get());

        } else {
            return el.Failure();
        }
    }
    return builder.constants.Composite(composite_ty, std::move(els));
}
}  // namespace detail
| |
| /// TransformElements constructs a new constant of type `composite_ty` by applying the |
| /// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all |
| /// input constants `cs` are of the same arity (all scalars or all vectors of the same size). |
| /// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside |
| /// the most deeply nested aggregate type will be passed in. |
| template <typename F, typename... CONSTANTS> |
| ConstEval::Result TransformElements(ProgramBuilder& builder, |
| const type::Type* composite_ty, |
| F&& f, |
| CONSTANTS&&... cs) { |
| return detail::TransformElements(builder, composite_ty, f, 0, cs...); |
| } |
| |
| /// TransformBinaryElements constructs a new constant of type `composite_ty` by applying the |
| /// transformation function 'f' on each of the most deeply nested elements of both `c0` and `c1`. |
| /// Unlike TransformElements, this function handles the constants being of different arity, e.g. |
| /// vector-scalar, scalar-vector. |
| template <typename F> |
| ConstEval::Result TransformBinaryElements(ProgramBuilder& builder, |
| const type::Type* composite_ty, |
| F&& f, |
| const constant::Value* c0, |
| const constant::Value* c1) { |
| uint32_t n0 = c0->Type()->Elements(nullptr, 1).count; |
| uint32_t n1 = c1->Type()->Elements(nullptr, 1).count; |
| uint32_t max_n = std::max(n0, n1); |
| // If arity of both constants is 1, invoke callback |
| if (max_n == 1u) { |
| return f(c0, c1); |
| } |
| |
| const auto* element_ty = composite_ty->Elements(composite_ty).type; |
| |
| utils::Vector<const constant::Value*, 8> els; |
| els.Reserve(max_n); |
| for (uint32_t i = 0; i < max_n; i++) { |
| auto nested_or_self = [&](auto* c, uint32_t num_elems) { |
| if (num_elems == 1) { |
| return c; |
| } |
| return c->Index(i); |
| }; |
| if (auto el = TransformBinaryElements(builder, element_ty, std::forward<F>(f), |
| nested_or_self(c0, n0), nested_or_self(c1, n1))) { |
| els.Push(el.Get()); |
| } else { |
| return el.Failure(); |
| } |
| } |
| return builder.constants.Composite(composite_ty, std::move(els)); |
| } |
| } // namespace |
| |
| ConstEval::ConstEval(ProgramBuilder& b, bool use_runtime_semantics /* = false */) |
| : builder(b), use_runtime_semantics_(use_runtime_semantics) {} |
| |
| template <typename T> |
| ConstEval::Result ConstEval::CreateScalar(const Source& source, const type::Type* t, T v) { |
| static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool"); |
| TINT_ASSERT(Resolver, t->Is<type::Scalar>()); |
| |
| if constexpr (IsFloatingPoint<T>) { |
| if (!std::isfinite(v.value)) { |
| AddError(OverflowErrorMessage(v, t->FriendlyName()), source); |
| if (use_runtime_semantics_) { |
| return ZeroValue(t); |
| } else { |
| return utils::Failure; |
| } |
| } |
| } |
| return builder.constants.Get<constant::Scalar<T>>(t, v); |
| } |
| |
| const constant::Value* ConstEval::ZeroValue(const type::Type* type) { |
| return Switch( |
| type, // |
| [&](const type::Vector* v) -> const constant::Value* { |
| auto* zero_el = ZeroValue(v->type()); |
| return builder.constants.Splat(type, zero_el, v->Width()); |
| }, |
| [&](const type::Matrix* m) -> const constant::Value* { |
| auto* zero_el = ZeroValue(m->ColumnType()); |
| return builder.constants.Splat(type, zero_el, m->columns()); |
| }, |
| [&](const type::Array* a) -> const constant::Value* { |
| if (auto n = a->ConstantCount()) { |
| if (auto* zero_el = ZeroValue(a->ElemType())) { |
| return builder.constants.Splat(type, zero_el, n.value()); |
| } |
| } |
| return nullptr; |
| }, |
| [&](const type::Struct* s) -> const constant::Value* { |
| utils::Hashmap<const type::Type*, const constant::Value*, 8> zero_by_type; |
| utils::Vector<const constant::Value*, 4> zeros; |
| zeros.Reserve(s->Members().Length()); |
| for (auto* member : s->Members()) { |
| auto* zero = zero_by_type.GetOrCreate(member->Type(), |
| [&] { return ZeroValue(member->Type()); }); |
| if (!zero) { |
| return nullptr; |
| } |
| zeros.Push(zero); |
| } |
| if (zero_by_type.Count() == 1) { |
| // All members were of the same type, so the zero value is the same for all members. |
| return builder.constants.Splat(type, zeros[0], s->Members().Length()); |
| } |
| return builder.constants.Composite(s, std::move(zeros)); |
| }, |
| [&](Default) -> const constant::Value* { |
| return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* { |
| auto el = CreateScalar(Source{}, type, zero); |
| TINT_ASSERT(Resolver, el); |
| return el.Get(); |
| }); |
| }); |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Add(const Source& source, NumberT a, NumberT b) { |
| NumberT result; |
| if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) { |
| if (auto r = CheckedAdd(a, b)) { |
| result = r->value; |
| } else { |
| AddError(OverflowErrorMessage(a, "+", b), source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } else { |
| using T = UnwrapNumber<NumberT>; |
| auto add_values = [](T lhs, T rhs) { |
| if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) { |
| // Ensure no UB for signed overflow |
| using UT = std::make_unsigned_t<T>; |
| return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs)); |
| } else { |
| return lhs + rhs; |
| } |
| }; |
| result = add_values(a.value, b.value); |
| } |
| return result; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Sub(const Source& source, NumberT a, NumberT b) { |
| NumberT result; |
| if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) { |
| if (auto r = CheckedSub(a, b)) { |
| result = r->value; |
| } else { |
| AddError(OverflowErrorMessage(a, "-", b), source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } else { |
| using T = UnwrapNumber<NumberT>; |
| auto sub_values = [](T lhs, T rhs) { |
| if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) { |
| // Ensure no UB for signed overflow |
| using UT = std::make_unsigned_t<T>; |
| return static_cast<T>(static_cast<UT>(lhs) - static_cast<UT>(rhs)); |
| } else { |
| return lhs - rhs; |
| } |
| }; |
| result = sub_values(a.value, b.value); |
| } |
| return result; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Mul(const Source& source, NumberT a, NumberT b) { |
| using T = UnwrapNumber<NumberT>; |
| NumberT result; |
| if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) { |
| if (auto r = CheckedMul(a, b)) { |
| result = r->value; |
| } else { |
| AddError(OverflowErrorMessage(a, "*", b), source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } else { |
| auto mul_values = [](T lhs, T rhs) { |
| if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) { |
| // For signed integrals, avoid C++ UB by multiplying as unsigned |
| using UT = std::make_unsigned_t<T>; |
| return static_cast<T>(static_cast<UT>(lhs) * static_cast<UT>(rhs)); |
| } else { |
| return lhs * rhs; |
| } |
| }; |
| result = mul_values(a.value, b.value); |
| } |
| return result; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Div(const Source& source, NumberT a, NumberT b) { |
| NumberT result; |
| if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) { |
| if (auto r = CheckedDiv(a, b)) { |
| result = r->value; |
| } else { |
| AddError(OverflowErrorMessage(a, "/", b), source); |
| if (use_runtime_semantics_) { |
| return a; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } else { |
| using T = UnwrapNumber<NumberT>; |
| auto lhs = a.value; |
| auto rhs = b.value; |
| if (rhs == 0) { |
| // For integers (as for floats), lhs / 0 is an error |
| AddError(OverflowErrorMessage(a, "/", b), source); |
| if (use_runtime_semantics_) { |
| return a; |
| } else { |
| return utils::Failure; |
| } |
| } |
| if constexpr (std::is_signed_v<T>) { |
| // For signed integers, lhs / -1 where lhs is the |
| // most negative value is an error |
| if (rhs == -1 && lhs == std::numeric_limits<T>::min()) { |
| AddError(OverflowErrorMessage(a, "/", b), source); |
| if (use_runtime_semantics_) { |
| return a; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } |
| result = lhs / rhs; |
| } |
| return result; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Mod(const Source& source, NumberT a, NumberT b) { |
| NumberT result; |
| if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) { |
| if (auto r = CheckedMod(a, b)) { |
| result = r->value; |
| } else { |
| AddError(OverflowErrorMessage(a, "%", b), source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } else { |
| using T = UnwrapNumber<NumberT>; |
| auto lhs = a.value; |
| auto rhs = b.value; |
| if (rhs == 0) { |
| // lhs % 0 is an error |
| AddError(OverflowErrorMessage(a, "%", b), source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| if constexpr (std::is_signed_v<T>) { |
| // For signed integers, lhs % -1 where lhs is the |
| // most negative value is an error |
| if (rhs == -1 && lhs == std::numeric_limits<T>::min()) { |
| AddError(OverflowErrorMessage(a, "%", b), source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| } |
| result = lhs % rhs; |
| } |
| return result; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Dot2(const Source& source, |
| NumberT a1, |
| NumberT a2, |
| NumberT b1, |
| NumberT b2) { |
| auto r1 = Mul(source, a1, b1); |
| if (!r1) { |
| return utils::Failure; |
| } |
| auto r2 = Mul(source, a2, b2); |
| if (!r2) { |
| return utils::Failure; |
| } |
| auto r = Add(source, r1.Get(), r2.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| return r; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Dot3(const Source& source, |
| NumberT a1, |
| NumberT a2, |
| NumberT a3, |
| NumberT b1, |
| NumberT b2, |
| NumberT b3) { |
| auto r1 = Mul(source, a1, b1); |
| if (!r1) { |
| return utils::Failure; |
| } |
| auto r2 = Mul(source, a2, b2); |
| if (!r2) { |
| return utils::Failure; |
| } |
| auto r3 = Mul(source, a3, b3); |
| if (!r3) { |
| return utils::Failure; |
| } |
| auto r = Add(source, r1.Get(), r2.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| r = Add(source, r.Get(), r3.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| return r; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Dot4(const Source& source, |
| NumberT a1, |
| NumberT a2, |
| NumberT a3, |
| NumberT a4, |
| NumberT b1, |
| NumberT b2, |
| NumberT b3, |
| NumberT b4) { |
| auto r1 = Mul(source, a1, b1); |
| if (!r1) { |
| return utils::Failure; |
| } |
| auto r2 = Mul(source, a2, b2); |
| if (!r2) { |
| return utils::Failure; |
| } |
| auto r3 = Mul(source, a3, b3); |
| if (!r3) { |
| return utils::Failure; |
| } |
| auto r4 = Mul(source, a4, b4); |
| if (!r4) { |
| return utils::Failure; |
| } |
| auto r = Add(source, r1.Get(), r2.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| r = Add(source, r.Get(), r3.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| r = Add(source, r.Get(), r4.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| return r; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Det2(const Source& source, |
| NumberT a, |
| NumberT b, |
| NumberT c, |
| NumberT d) { |
| // | a c | |
| // | b d | |
| // |
| // = |
| // |
| // a * d - c * b |
| |
| auto r1 = Mul(source, a, d); |
| if (!r1) { |
| return utils::Failure; |
| } |
| auto r2 = Mul(source, c, b); |
| if (!r2) { |
| return utils::Failure; |
| } |
| auto r = Sub(source, r1.Get(), r2.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| return r; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Det3(const Source& source, |
| NumberT a, |
| NumberT b, |
| NumberT c, |
| NumberT d, |
| NumberT e, |
| NumberT f, |
| NumberT g, |
| NumberT h, |
| NumberT i) { |
| // | a d g | |
| // | b e h | |
| // | c f i | |
| // |
| // = |
| // |
| // a | e h | - d | b h | + g | b e | |
| // | f i | | c i | | c f | |
| |
| auto det1 = Det2(source, e, f, h, i); |
| if (!det1) { |
| return utils::Failure; |
| } |
| auto a_det1 = Mul(source, a, det1.Get()); |
| if (!a_det1) { |
| return utils::Failure; |
| } |
| auto det2 = Det2(source, b, c, h, i); |
| if (!det2) { |
| return utils::Failure; |
| } |
| auto d_det2 = Mul(source, d, det2.Get()); |
| if (!d_det2) { |
| return utils::Failure; |
| } |
| auto det3 = Det2(source, b, c, e, f); |
| if (!det3) { |
| return utils::Failure; |
| } |
| auto g_det3 = Mul(source, g, det3.Get()); |
| if (!g_det3) { |
| return utils::Failure; |
| } |
| auto r = Sub(source, a_det1.Get(), d_det2.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| return Add(source, r.Get(), g_det3.Get()); |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Det4(const Source& source, |
| NumberT a, |
| NumberT b, |
| NumberT c, |
| NumberT d, |
| NumberT e, |
| NumberT f, |
| NumberT g, |
| NumberT h, |
| NumberT i, |
| NumberT j, |
| NumberT k, |
| NumberT l, |
| NumberT m, |
| NumberT n, |
| NumberT o, |
| NumberT p) { |
| // | a e i m | |
| // | b f j n | |
| // | c g k o | |
| // | d h l p | |
| // |
| // = |
| // |
| // a | f j n | - e | b j n | + i | b f n | - m | b f j | |
| // | g k o | | c k o | | c g o | | c g k | |
| // | h l p | | d l p | | d h p | | d h l | |
| |
| auto det1 = Det3(source, f, g, h, j, k, l, n, o, p); |
| if (!det1) { |
| return utils::Failure; |
| } |
| auto a_det1 = Mul(source, a, det1.Get()); |
| if (!a_det1) { |
| return utils::Failure; |
| } |
| auto det2 = Det3(source, b, c, d, j, k, l, n, o, p); |
| if (!det2) { |
| return utils::Failure; |
| } |
| auto e_det2 = Mul(source, e, det2.Get()); |
| if (!e_det2) { |
| return utils::Failure; |
| } |
| auto det3 = Det3(source, b, c, d, f, g, h, n, o, p); |
| if (!det3) { |
| return utils::Failure; |
| } |
| auto i_det3 = Mul(source, i, det3.Get()); |
| if (!i_det3) { |
| return utils::Failure; |
| } |
| auto det4 = Det3(source, b, c, d, f, g, h, j, k, l); |
| if (!det4) { |
| return utils::Failure; |
| } |
| auto m_det4 = Mul(source, m, det4.Get()); |
| if (!m_det4) { |
| return utils::Failure; |
| } |
| auto r = Sub(source, a_det1.Get(), e_det2.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| r = Add(source, r.Get(), i_det3.Get()); |
| if (!r) { |
| return utils::Failure; |
| } |
| return Sub(source, r.Get(), m_det4.Get()); |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) { |
| if (v < NumberT(0)) { |
| AddError("sqrt must be called with a value >= 0", source); |
| if (use_runtime_semantics_) { |
| return NumberT{0}; |
| } else { |
| return utils::Failure; |
| } |
| } |
| return NumberT{std::sqrt(v)}; |
| } |
| |
| auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto v) -> ConstEval::Result { |
| if (auto r = Sqrt(source, v)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| template <typename NumberT> |
| utils::Result<NumberT> ConstEval::Clamp(const Source&, NumberT e, NumberT low, NumberT high) { |
| return NumberT{std::min(std::max(e, low), high)}; |
| } |
| |
| auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto e, auto low, auto high) -> ConstEval::Result { |
| if (auto r = Clamp(source, e, low, high)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2) -> ConstEval::Result { |
| if (auto r = Add(source, a1, a2)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2) -> ConstEval::Result { |
| if (auto r = Sub(source, a1, a2)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2) -> ConstEval::Result { |
| if (auto r = Mul(source, a1, a2)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2) -> ConstEval::Result { |
| if (auto r = Div(source, a1, a2)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2) -> ConstEval::Result { |
| if (auto r = Mod(source, a1, a2)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2, auto b1, auto b2) -> ConstEval::Result { |
| if (auto r = Dot2(source, a1, a2, b1, b2)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ConstEval::Result { |
| if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::Dot4Func(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, |
| auto b4) -> ConstEval::Result { |
| if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| ConstEval::Result ConstEval::Dot(const Source& source, |
| const constant::Value* v1, |
| const constant::Value* v2) { |
| auto* vec_ty = v1->Type()->As<type::Vector>(); |
| TINT_ASSERT(Resolver, vec_ty); |
| auto* elem_ty = vec_ty->type(); |
| switch (vec_ty->Width()) { |
| case 2: |
| return Dispatch_fia_fiu32_f16( // |
| Dot2Func(source, elem_ty), // |
| v1->Index(0), v1->Index(1), // |
| v2->Index(0), v2->Index(1)); |
| case 3: |
| return Dispatch_fia_fiu32_f16( // |
| Dot3Func(source, elem_ty), // |
| v1->Index(0), v1->Index(1), v1->Index(2), // |
| v2->Index(0), v2->Index(1), v2->Index(2)); |
| case 4: |
| return Dispatch_fia_fiu32_f16( // |
| Dot4Func(source, elem_ty), // |
| v1->Index(0), v1->Index(1), v1->Index(2), v1->Index(3), // |
| v2->Index(0), v2->Index(1), v2->Index(2), v2->Index(3)); |
| } |
| TINT_ICE(Resolver, builder.Diagnostics()) << "Expected vector"; |
| return utils::Failure; |
| } |
| |
| ConstEval::Result ConstEval::Length(const Source& source, |
| const type::Type* ty, |
| const constant::Value* c0) { |
| auto* vec_ty = c0->Type()->As<type::Vector>(); |
| // Evaluates to the absolute value of e if T is scalar. |
| if (vec_ty == nullptr) { |
| auto create = [&](auto e) { |
| using NumberT = decltype(e); |
| return CreateScalar(source, ty, NumberT{std::abs(e)}); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| } |
| |
| // Evaluates to sqrt(e[0]^2 + e[1]^2 + ...) if T is a vector type. |
| auto d = Dot(source, c0, c0); |
| if (!d) { |
| return utils::Failure; |
| } |
| return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get()); |
| } |
| |
| ConstEval::Result ConstEval::Mul(const Source& source, |
| const type::Type* ty, |
| const constant::Value* v1, |
| const constant::Value* v2) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| return Dispatch_fia_fiu32_f16(MulFunc(source, c0->Type()), c0, c1); |
| }; |
| return TransformBinaryElements(builder, ty, transform, v1, v2); |
| } |
| |
| ConstEval::Result ConstEval::Sub(const Source& source, |
| const type::Type* ty, |
| const constant::Value* v1, |
| const constant::Value* v2) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| return Dispatch_fia_fiu32_f16(SubFunc(source, c0->Type()), c0, c1); |
| }; |
| return TransformBinaryElements(builder, ty, transform, v1, v2); |
| } |
| |
| auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a, auto b, auto c, auto d) -> ConstEval::Result { |
| if (auto r = Det2(source, a, b, c, d)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::Det3Func(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, |
| auto i) -> ConstEval::Result { |
| if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| auto ConstEval::Det4Func(const Source& source, const type::Type* elem_ty) { |
| return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j, |
| auto k, auto l, auto m, auto n, auto o, auto p) -> ConstEval::Result { |
| if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) { |
| return CreateScalar(source, elem_ty, r.Get()); |
| } |
| return utils::Failure; |
| }; |
| } |
| |
| ConstEval::Result ConstEval::Literal(const type::Type* ty, const ast::LiteralExpression* literal) { |
| auto& source = literal->source; |
| return Switch( |
| literal, |
| [&](const ast::BoolLiteralExpression* lit) { return CreateScalar(source, ty, lit->value); }, |
| [&](const ast::IntLiteralExpression* lit) -> ConstEval::Result { |
| switch (lit->suffix) { |
| case ast::IntLiteralExpression::Suffix::kNone: |
| return CreateScalar(source, ty, AInt(lit->value)); |
| case ast::IntLiteralExpression::Suffix::kI: |
| return CreateScalar(source, ty, i32(lit->value)); |
| case ast::IntLiteralExpression::Suffix::kU: |
| return CreateScalar(source, ty, u32(lit->value)); |
| } |
| return nullptr; |
| }, |
| [&](const ast::FloatLiteralExpression* lit) -> ConstEval::Result { |
| switch (lit->suffix) { |
| case ast::FloatLiteralExpression::Suffix::kNone: |
| return CreateScalar(source, ty, AFloat(lit->value)); |
| case ast::FloatLiteralExpression::Suffix::kF: |
| return CreateScalar(source, ty, f32(lit->value)); |
| case ast::FloatLiteralExpression::Suffix::kH: |
| return CreateScalar(source, ty, f16(lit->value)); |
| } |
| return nullptr; |
| }); |
| } |
| |
| ConstEval::Result ConstEval::ArrayOrStructCtor(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args) { |
| if (args.IsEmpty()) { |
| return ZeroValue(ty); |
| } |
| |
| if (args.Length() == 1 && args[0]->Type() == ty) { |
| // Identity constructor. |
| return args[0]; |
| } |
| |
| // Multiple arguments. Must be a value constructor. |
| return builder.constants.Composite(ty, std::move(args)); |
| } |
| |
| ConstEval::Result ConstEval::Conv(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto* el_ty = ty->Elements(ty).type; |
| if (!el_ty) { |
| return nullptr; |
| } |
| |
| if (!args[0]) { |
| return nullptr; // Single argument is not constant. |
| } |
| |
| return Convert(ty, args[0], source); |
| } |
| |
| ConstEval::Result ConstEval::Zero(const type::Type* ty, |
| utils::VectorRef<const constant::Value*>, |
| const Source&) { |
| return ZeroValue(ty); |
| } |
| |
| ConstEval::Result ConstEval::Identity(const type::Type*, |
| utils::VectorRef<const constant::Value*> args, |
| const Source&) { |
| return args[0]; |
| } |
| |
| ConstEval::Result ConstEval::VecSplat(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source&) { |
| if (auto* arg = args[0]) { |
| return builder.constants.Splat(ty, arg, static_cast<const type::Vector*>(ty)->Width()); |
| } |
| return nullptr; |
| } |
| |
| ConstEval::Result ConstEval::VecInitS(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source&) { |
| return builder.constants.Composite(ty, args); |
| } |
| |
| ConstEval::Result ConstEval::VecInitM(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source&) { |
| utils::Vector<const constant::Value*, 4> els; |
| for (auto* arg : args) { |
| auto* val = arg; |
| if (!val) { |
| return nullptr; |
| } |
| auto* arg_ty = arg->Type(); |
| if (auto* arg_vec = arg_ty->As<type::Vector>()) { |
| // Extract out vector elements. |
| for (uint32_t j = 0; j < arg_vec->Width(); j++) { |
| auto* el = val->Index(j); |
| if (!el) { |
| return nullptr; |
| } |
| els.Push(el); |
| } |
| } else { |
| els.Push(val); |
| } |
| } |
| return builder.constants.Composite(ty, std::move(els)); |
| } |
| |
| ConstEval::Result ConstEval::MatInitS(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source&) { |
| auto* m = static_cast<const type::Matrix*>(ty); |
| |
| utils::Vector<const constant::Value*, 4> els; |
| for (uint32_t c = 0; c < m->columns(); c++) { |
| utils::Vector<const constant::Value*, 4> column; |
| for (uint32_t r = 0; r < m->rows(); r++) { |
| auto i = r + c * m->rows(); |
| column.Push(args[i]); |
| } |
| els.Push(builder.constants.Composite(m->ColumnType(), std::move(column))); |
| } |
| return builder.constants.Composite(ty, std::move(els)); |
| } |
| |
| ConstEval::Result ConstEval::MatInitV(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source&) { |
| return builder.constants.Composite(ty, args); |
| } |
| |
| ConstEval::Result ConstEval::Index(const type::Type* ty, |
| const sem::ValueExpression* obj_expr, |
| const sem::ValueExpression* idx_expr) { |
| auto idx_val = idx_expr->ConstantValue(); |
| if (!idx_val) { |
| return nullptr; |
| } |
| |
| uint32_t el_count = obj_expr->Type()->UnwrapRef()->Elements().count; |
| |
| AInt idx = idx_val->ValueAs<AInt>(); |
| if (idx < 0 || (el_count > 0 && idx >= el_count)) { |
| std::string range; |
| if (el_count > 0) { |
| range = " [0.." + std::to_string(el_count - 1) + "]"; |
| } |
| AddError("index " + std::to_string(idx) + " out of bounds" + range, |
| idx_expr->Declaration()->source); |
| if (use_runtime_semantics_) { |
| return ZeroValue(ty); |
| } else { |
| return utils::Failure; |
| } |
| } |
| |
| auto obj_val = obj_expr->ConstantValue(); |
| if (!obj_val) { |
| return nullptr; |
| } |
| |
| return obj_val->Index(static_cast<size_t>(idx)); |
| } |
| |
| ConstEval::Result ConstEval::MemberAccess(const sem::ValueExpression* obj_expr, |
| const type::StructMember* member) { |
| auto obj_val = obj_expr->ConstantValue(); |
| if (!obj_val) { |
| return nullptr; |
| } |
| return obj_val->Index(static_cast<size_t>(member->Index())); |
| } |
| |
| ConstEval::Result ConstEval::Swizzle(const type::Type* ty, |
| const sem::ValueExpression* vec_expr, |
| utils::VectorRef<uint32_t> indices) { |
| auto* vec_val = vec_expr->ConstantValue(); |
| if (!vec_val) { |
| return nullptr; |
| } |
| if (indices.Length() == 1) { |
| return vec_val->Index(static_cast<size_t>(indices[0])); |
| } |
| auto values = utils::Transform<4>( |
| indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); }); |
| return builder.constants.Composite(ty, std::move(values)); |
| } |
| |
| ConstEval::Result ConstEval::Bitcast(const type::Type* ty, |
| const constant::Value* value, |
| const Source& source) { |
| auto* el_ty = ty->DeepestElement(); |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) { |
| return Switch( |
| el_ty, |
| [&](const type::U32*) { // |
| auto r = utils::Bitcast<u32>(e); |
| return CreateScalar(source, el_ty, r); |
| }, |
| [&](const type::I32*) { // |
| auto r = utils::Bitcast<i32>(e); |
| return CreateScalar(source, el_ty, r); |
| }, |
| [&](const type::F32*) { // |
| auto r = utils::Bitcast<f32>(e); |
| return CreateScalar(source, el_ty, r); |
| }); |
| }; |
| return Dispatch_fiu32(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, value); |
| } |
| |
| ConstEval::Result ConstEval::OpComplement(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c) { |
| auto create = [&](auto i) { |
| return CreateScalar(source, c->Type(), decltype(i)(~i.value)); |
| }; |
| return Dispatch_ia_iu32(create, c); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::OpUnaryMinus(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c) { |
| auto create = [&](auto i) { |
| // For signed integrals, avoid C++ UB by not negating the |
| // smallest negative number. In WGSL, this operation is well |
| // defined to return the same value, see: |
| // https://gpuweb.github.io/gpuweb/wgsl/#arithmetic-expr. |
| using T = UnwrapNumber<decltype(i)>; |
| if constexpr (std::is_integral_v<T>) { |
| auto v = i.value; |
| if (v != std::numeric_limits<T>::min()) { |
| v = -v; |
| } |
| return CreateScalar(source, c->Type(), decltype(i)(v)); |
| } else { |
| return CreateScalar(source, c->Type(), decltype(i)(-i.value)); |
| } |
| }; |
| return Dispatch_fia_fi32_f16(create, c); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::OpNot(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c) { |
| auto create = [&](auto i) { return CreateScalar(source, c->Type(), decltype(i)(!i)); }; |
| return Dispatch_bool(create, c); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::OpPlus(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| return Dispatch_fia_fiu32_f16(AddFunc(source, c0->Type()), c0, c1); |
| }; |
| |
| return TransformBinaryElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpMinus(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| return Sub(source, ty, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpMultiply(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| return Mul(source, ty, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto* mat_ty = args[0]->Type()->As<type::Matrix>(); |
| auto* vec_ty = args[1]->Type()->As<type::Vector>(); |
| auto* elem_ty = vec_ty->type(); |
| |
| auto dot = [&](const constant::Value* m, size_t row, const constant::Value* v) { |
| ConstEval::Result result; |
| switch (mat_ty->columns()) { |
| case 2: |
| result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), // |
| m->Index(0)->Index(row), // |
| m->Index(1)->Index(row), // |
| v->Index(0), // |
| v->Index(1)); |
| break; |
| case 3: |
| result = Dispatch_fa_f32_f16(Dot3Func(source, elem_ty), // |
| m->Index(0)->Index(row), // |
| m->Index(1)->Index(row), // |
| m->Index(2)->Index(row), // |
| v->Index(0), // |
| v->Index(1), v->Index(2)); |
| break; |
| case 4: |
| result = Dispatch_fa_f32_f16(Dot4Func(source, elem_ty), // |
| m->Index(0)->Index(row), // |
| m->Index(1)->Index(row), // |
| m->Index(2)->Index(row), // |
| m->Index(3)->Index(row), // |
| v->Index(0), // |
| v->Index(1), // |
| v->Index(2), // |
| v->Index(3)); |
| break; |
| } |
| return result; |
| }; |
| |
| utils::Vector<const constant::Value*, 4> result; |
| for (size_t i = 0; i < mat_ty->rows(); ++i) { |
| auto r = dot(args[0], i, args[1]); // matrix row i * vector |
| if (!r) { |
| return utils::Failure; |
| } |
| result.Push(r.Get()); |
| } |
| return builder.constants.Composite(ty, result); |
| } |
| ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto* vec_ty = args[0]->Type()->As<type::Vector>(); |
| auto* mat_ty = args[1]->Type()->As<type::Matrix>(); |
| auto* elem_ty = vec_ty->type(); |
| |
| auto dot = [&](const constant::Value* v, const constant::Value* m, size_t col) { |
| ConstEval::Result result; |
| switch (mat_ty->rows()) { |
| case 2: |
| result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), // |
| m->Index(col)->Index(0), // |
| m->Index(col)->Index(1), // |
| v->Index(0), // |
| v->Index(1)); |
| break; |
| case 3: |
| result = Dispatch_fa_f32_f16(Dot3Func(source, elem_ty), // |
| m->Index(col)->Index(0), // |
| m->Index(col)->Index(1), // |
| m->Index(col)->Index(2), |
| v->Index(0), // |
| v->Index(1), // |
| v->Index(2)); |
| break; |
| case 4: |
| result = Dispatch_fa_f32_f16(Dot4Func(source, elem_ty), // |
| m->Index(col)->Index(0), // |
| m->Index(col)->Index(1), // |
| m->Index(col)->Index(2), // |
| m->Index(col)->Index(3), // |
| v->Index(0), // |
| v->Index(1), // |
| v->Index(2), // |
| v->Index(3)); |
| } |
| return result; |
| }; |
| |
| utils::Vector<const constant::Value*, 4> result; |
| for (size_t i = 0; i < mat_ty->columns(); ++i) { |
| auto r = dot(args[0], args[1], i); // vector * matrix col i |
| if (!r) { |
| return utils::Failure; |
| } |
| result.Push(r.Get()); |
| } |
| return builder.constants.Composite(ty, result); |
| } |
| |
| ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto* mat1 = args[0]; |
| auto* mat2 = args[1]; |
| auto* mat1_ty = mat1->Type()->As<type::Matrix>(); |
| auto* mat2_ty = mat2->Type()->As<type::Matrix>(); |
| auto* elem_ty = mat1_ty->type(); |
| |
| auto dot = [&](const constant::Value* m1, size_t row, const constant::Value* m2, size_t col) { |
| auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; |
| auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; |
| |
| ConstEval::Result result; |
| switch (mat1_ty->columns()) { |
| case 2: |
| result = Dispatch_fa_f32_f16(Dot2Func(source, elem_ty), // |
| m1e(row, 0), // |
| m1e(row, 1), // |
| m2e(0, col), // |
| m2e(1, col)); |
| break; |
| case 3: |
| result = Dispatch_fa_f32_f16(Dot3Func(source, elem_ty), // |
| m1e(row, 0), // |
| m1e(row, 1), // |
| m1e(row, 2), // |
| m2e(0, col), // |
| m2e(1, col), // |
| m2e(2, col)); |
| break; |
| case 4: |
| result = Dispatch_fa_f32_f16(Dot4Func(source, elem_ty), // |
| m1e(row, 0), // |
| m1e(row, 1), // |
| m1e(row, 2), // |
| m1e(row, 3), // |
| m2e(0, col), // |
| m2e(1, col), // |
| m2e(2, col), // |
| m2e(3, col)); |
| break; |
| } |
| return result; |
| }; |
| |
| utils::Vector<const constant::Value*, 4> result_mat; |
| for (size_t c = 0; c < mat2_ty->columns(); ++c) { |
| utils::Vector<const constant::Value*, 4> col_vec; |
| for (size_t r = 0; r < mat1_ty->rows(); ++r) { |
| auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c |
| if (!v) { |
| return utils::Failure; |
| } |
| col_vec.Push(v.Get()); // mat1 row r * mat2 col c |
| } |
| |
| // Add column vector to matrix |
| auto* col_vec_ty = ty->As<type::Matrix>()->ColumnType(); |
| result_mat.Push(builder.constants.Composite(col_vec_ty, col_vec)); |
| } |
| return builder.constants.Composite(ty, result_mat); |
| } |
| |
| ConstEval::Result ConstEval::OpDivide(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| return Dispatch_fia_fiu32_f16(DivFunc(source, c0->Type()), c0, c1); |
| }; |
| |
| return TransformBinaryElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpModulo(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| return Dispatch_fia_fiu32_f16(ModFunc(source, c0->Type()), c0, c1); |
| }; |
| |
| return TransformBinaryElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpEqual(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), i == j); |
| }; |
| return Dispatch_fia_fiu32_f16_bool(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), i != j); |
| }; |
| return Dispatch_fia_fiu32_f16_bool(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpLessThan(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), i < j); |
| }; |
| return Dispatch_fia_fiu32_f16(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), i > j); |
| }; |
| return Dispatch_fia_fiu32_f16(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), i <= j); |
| }; |
| return Dispatch_fia_fiu32_f16(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), i >= j); |
| }; |
| return Dispatch_fia_fiu32_f16(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| // Due to short-circuiting, this function is only called if lhs is true, so we only return the |
| // value of the rhs. |
| TINT_ASSERT(Resolver, args[0]->ValueAs<bool>()); |
| return CreateScalar(source, ty, args[1]->ValueAs<bool>()); |
| } |
| |
| ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| // Due to short-circuiting, this function is only called if lhs is false, so we only only return |
| // the value of the rhs. |
| TINT_ASSERT(Resolver, !args[0]->ValueAs<bool>()); |
| return CreateScalar(source, ty, args[1]->ValueAs<bool>()); |
| } |
| |
| ConstEval::Result ConstEval::OpAnd(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| using T = decltype(i); |
| T result; |
| if constexpr (std::is_same_v<T, bool>) { |
| result = i && j; |
| } else { // integral |
| result = i & j; |
| } |
| return CreateScalar(source, ty->DeepestElement(), result); |
| }; |
| return Dispatch_ia_iu32_bool(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpOr(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| using T = decltype(i); |
| T result; |
| if constexpr (std::is_same_v<T, bool>) { |
| result = i || j; |
| } else { // integral |
| result = i | j; |
| } |
| return CreateScalar(source, ty->DeepestElement(), result); |
| }; |
| return Dispatch_ia_iu32_bool(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::OpXor(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) -> ConstEval::Result { |
| return CreateScalar(source, ty->DeepestElement(), decltype(i){i ^ j}); |
| }; |
| return Dispatch_ia_iu32(create, c0, c1); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
// Evaluates `e1 << e2` element by element. The element type of the rhs must be u32; anything
// else is reported as an internal compiler error.
//
// Abstract lhs:
//   * e2 < bit width: error if the e2 + 1 most significant bits of e1 are not all equal
//     (the shift would change the sign).
//   * e2 >= bit width: error if e1 is non-zero. The shift amount is then forced to 0.
// Concrete lhs:
//   * e2 >= bit width: error. Under runtime semantics the shift amount is reduced modulo the
//     bit width.
//   * signed lhs: error if the e2 + 1 most significant bits of e1 are not all equal.
//   * unsigned lhs: error if any of the e2 most significant bits of e1 are set.
// When use_runtime_semantics_ is set, these errors are still reported, but evaluation
// continues instead of returning a failure.
ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty,
                                         utils::VectorRef<const constant::Value*> args,
                                         const Source& source) {
    auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
        auto create = [&](auto e1, auto e2) -> ConstEval::Result {
            using NumberT = decltype(e1);
            using T = UnwrapNumber<NumberT>;
            using UT = std::make_unsigned_t<T>;
            constexpr size_t bit_width = BitWidth<NumberT>;
            UT e1u = static_cast<UT>(e1);
            UT e2u = static_cast<UT>(e2);

            if constexpr (IsAbstract<NumberT>) {
                // The e2 + 1 most significant bits of e1 must have the same bit value, otherwise
                // sign change (overflow) would occur.
                // Check sign change only if e2 is less than bit width of e1. If e2 is greater
                // than or equal to the bit width, the else branch below checks for a
                // non-representable value instead.
                if (e2u < bit_width) {
                    UT must_match_msb = e2u + 1;
                    UT mask = ~UT{0} << (bit_width - must_match_msb);
                    if ((e1u & mask) != 0 && (e1u & mask) != mask) {
                        AddError("shift left operation results in sign change", source);
                        if (!use_runtime_semantics_) {
                            return utils::Failure;
                        }
                    }
                } else {
                    // If shift value >= bit_width, then any non-zero value would overflow
                    if (e1 != 0) {
                        AddError(OverflowErrorMessage(e1, "<<", e2), source);
                        if (!use_runtime_semantics_) {
                            return utils::Failure;
                        }
                    }

                    // It's UB in C++ to shift by greater or equal to the bit width (even if the lhs
                    // is 0), so we make sure to avoid this by setting the shift value to 0.
                    e2u = 0;
                }
            } else {
                if (static_cast<size_t>(e2) >= bit_width) {
                    // At shader/pipeline-creation time, it is an error to shift by the bit width of
                    // the lhs or greater.
                    // NOTE: At runtime, we shift by e2 % (bit width of e1).
                    AddError(
                        "shift left value must be less than the bit width of the lhs, which is " +
                            std::to_string(bit_width),
                        source);
                    if (use_runtime_semantics_) {
                        e2u = e2u % bit_width;
                    } else {
                        return utils::Failure;
                    }
                }

                if constexpr (std::is_signed_v<T>) {
                    // If T is a signed integer type, and the e2+1 most significant bits of e1 do
                    // not have the same bit value, then error.
                    // e2u < bit_width here, so the mask shift below is in range.
                    size_t must_match_msb = e2u + 1;
                    UT mask = ~UT{0} << (bit_width - must_match_msb);
                    if ((e1u & mask) != 0 && (e1u & mask) != mask) {
                        AddError("shift left operation results in sign change", source);
                        if (!use_runtime_semantics_) {
                            return utils::Failure;
                        }
                    }
                } else {
                    // If T is an unsigned integer type, and any of the e2 most significant bits of
                    // e1 are 1, then error.
                    // e2u == 0 is skipped: a mask shift by bit_width would be UB.
                    if (e2u > 0) {
                        size_t must_be_zero_msb = e2u;
                        UT mask = ~UT{0} << (bit_width - must_be_zero_msb);
                        if ((e1u & mask) != 0) {
                            AddError(OverflowErrorMessage(e1, "<<", e2), source);
                            if (!use_runtime_semantics_) {
                                return utils::Failure;
                            }
                        }
                    }
                }
            }

            // Avoid UB by left shifting as unsigned value
            auto result = static_cast<T>(static_cast<UT>(e1) << e2u);
            return CreateScalar(source, ty->DeepestElement(), NumberT{result});
        };
        return Dispatch_ia_iu32(create, c0, c1);
    };

    // The rhs of a shift is always u32; anything else indicates a resolver bug.
    if (TINT_UNLIKELY(!args[1]->Type()->DeepestElement()->Is<type::U32>())) {
        TINT_ICE(Resolver, builder.Diagnostics())
            << "Element type of rhs of ShiftLeft must be a u32";
        return utils::Failure;
    }

    return TransformElements(builder, ty, transform, args[0], args[1]);
}
| |
| ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto e1, auto e2) -> ConstEval::Result { |
| using NumberT = decltype(e1); |
| using T = UnwrapNumber<NumberT>; |
| using UT = std::make_unsigned_t<T>; |
| const size_t bit_width = BitWidth<NumberT>; |
| UT e1u = static_cast<UT>(e1); |
| UT e2u = static_cast<UT>(e2); |
| |
| auto signed_shift_right = [&] { |
| // In C++, right shift of a signed negative number is implementation-defined. |
| // Although most implementations sign-extend, we do it manually to ensure it works |
| // correctly on all implementations. |
| const UT msb = UT{1} << (bit_width - 1); |
| UT sign_ext = 0; |
| if (e1u & msb) { |
| // Set e2 + 1 bits to 1 |
| UT num_shift_bits_mask = ((UT{1} << e2u) - UT{1}); |
| sign_ext = (num_shift_bits_mask << (bit_width - e2u - UT{1})) | msb; |
| } |
| return static_cast<T>((e1u >> e2u) | sign_ext); |
| }; |
| |
| T result = 0; |
| if constexpr (IsAbstract<NumberT>) { |
| if (static_cast<size_t>(e2) >= bit_width) { |
| result = T{0}; |
| } else { |
| result = signed_shift_right(); |
| } |
| } else { |
| if (static_cast<size_t>(e2) >= bit_width) { |
| // At shader/pipeline-creation time, it is an error to shift by the bit width of |
| // the lhs or greater. NOTE: At runtime, we shift by e2 % (bit width of e1). |
| AddError( |
| "shift right value must be less than the bit width of the lhs, which is " + |
| std::to_string(bit_width), |
| source); |
| if (use_runtime_semantics_) { |
| e2u = e2u % bit_width; |
| } else { |
| return utils::Failure; |
| } |
| } |
| |
| if constexpr (std::is_signed_v<T>) { |
| result = signed_shift_right(); |
| } else { |
| result = e1 >> e2u; |
| } |
| } |
| return CreateScalar(source, ty->DeepestElement(), NumberT{result}); |
| }; |
| return Dispatch_ia_iu32(create, c0, c1); |
| }; |
| |
| if (TINT_UNLIKELY(!args[1]->Type()->DeepestElement()->Is<type::U32>())) { |
| TINT_ICE(Resolver, builder.Diagnostics()) |
| << "Element type of rhs of ShiftLeft must be a u32"; |
| return utils::Failure; |
| } |
| |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::abs(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) { |
| using NumberT = decltype(e); |
| NumberT result; |
| if constexpr (IsUnsignedIntegral<NumberT>) { |
| result = e; |
| } else if constexpr (IsSignedIntegral<NumberT>) { |
| if (e == NumberT::Lowest()) { |
| result = e; |
| } else { |
| result = NumberT{std::abs(e)}; |
| } |
| } else { |
| result = NumberT{std::abs(e)}; |
| } |
| return CreateScalar(source, c0->Type(), result); |
| }; |
| return Dispatch_fia_fiu32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::acos(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) -> ConstEval::Result { |
| using NumberT = decltype(i); |
| if (i < NumberT(-1.0) || i > NumberT(1.0)) { |
| AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)", |
| source); |
| if (use_runtime_semantics_) { |
| return ZeroValue(c0->Type()); |
| } else { |
| return utils::Failure; |
| } |
| } |
| return CreateScalar(source, c0->Type(), NumberT(std::acos(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::acosh(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) -> ConstEval::Result { |
| using NumberT = decltype(i); |
| if (i < NumberT(1.0)) { |
| AddError("acosh must be called with a value >= 1.0", source); |
| if (use_runtime_semantics_) { |
| return ZeroValue(c0->Type()); |
| } else { |
| return utils::Failure; |
| } |
| } |
| return CreateScalar(source, c0->Type(), NumberT(std::acosh(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::all(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| return CreateScalar(source, ty, !args[0]->AnyZero()); |
| } |
| |
| ConstEval::Result ConstEval::any(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| return CreateScalar(source, ty, !args[0]->AllZero()); |
| } |
| |
| ConstEval::Result ConstEval::asin(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) -> ConstEval::Result { |
| using NumberT = decltype(i); |
| if (i < NumberT(-1.0) || i > NumberT(1.0)) { |
| AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)", |
| source); |
| if (use_runtime_semantics_) { |
| return ZeroValue(c0->Type()); |
| } else { |
| return utils::Failure; |
| } |
| } |
| return CreateScalar(source, c0->Type(), NumberT(std::asin(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::asinh(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) { |
| return CreateScalar(source, c0->Type(), decltype(i)(std::asinh(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::atan(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) { |
| return CreateScalar(source, c0->Type(), decltype(i)(std::atan(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::atanh(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) -> ConstEval::Result { |
| using NumberT = decltype(i); |
| if (i <= NumberT(-1.0) || i >= NumberT(1.0)) { |
| AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)", |
| source); |
| if (use_runtime_semantics_) { |
| return ZeroValue(c0->Type()); |
| } else { |
| return utils::Failure; |
| } |
| } |
| return CreateScalar(source, c0->Type(), NumberT(std::atanh(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::atan2(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1) { |
| auto create = [&](auto i, auto j) { |
| return CreateScalar(source, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0, c1); |
| }; |
| return TransformElements(builder, ty, transform, args[0], args[1]); |
| } |
| |
| ConstEval::Result ConstEval::ceil(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) { |
| return CreateScalar(source, c0->Type(), decltype(e)(std::ceil(e))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::clamp(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0, const constant::Value* c1, |
| const constant::Value* c2) { |
| return Dispatch_fia_fiu32_f16(ClampFunc(source, c0->Type()), c0, c1, c2); |
| }; |
| return TransformElements(builder, ty, transform, args[0], args[1], args[2]); |
| } |
| |
| ConstEval::Result ConstEval::cos(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) -> ConstEval::Result { |
| using NumberT = decltype(i); |
| return CreateScalar(source, c0->Type(), NumberT(std::cos(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::cosh(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto i) -> ConstEval::Result { |
| using NumberT = decltype(i); |
| return CreateScalar(source, c0->Type(), NumberT(std::cosh(i.value))); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::countLeadingZeros(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) { |
| using NumberT = decltype(e); |
| using T = UnwrapNumber<NumberT>; |
| auto count = CountLeadingBits(T{e}, T{0}); |
| return CreateScalar(source, c0->Type(), NumberT(count)); |
| }; |
| return Dispatch_iu32(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::countOneBits(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) { |
| using NumberT = decltype(e); |
| using T = UnwrapNumber<NumberT>; |
| using UT = std::make_unsigned_t<T>; |
| constexpr UT kRightMost = UT{1}; |
| |
| auto count = UT{0}; |
| for (auto v = static_cast<UT>(e); v != UT{0}; v >>= 1) { |
| if ((v & kRightMost) == 1) { |
| ++count; |
| } |
| } |
| |
| return CreateScalar(source, c0->Type(), NumberT(count)); |
| }; |
| return Dispatch_iu32(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::countTrailingZeros(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) { |
| using NumberT = decltype(e); |
| using T = UnwrapNumber<NumberT>; |
| auto count = CountTrailingBits(T{e}, T{0}); |
| return CreateScalar(source, c0->Type(), NumberT(count)); |
| }; |
| return Dispatch_iu32(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::cross(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto* u = args[0]; |
| auto* v = args[1]; |
| auto* elem_ty = u->Type()->As<type::Vector>()->type(); |
| |
| // cross product of a v3 is the determinant of the 3x3 matrix: |
| // |
| // |i j k | |
| // |u0 u1 u2| |
| // |v0 v1 v2| |
| // |
| // |u1 u2|i - |u0 u2|j + |u0 u1|k |
| // |v1 v2| |v0 v2| |v0 v1| |
| // |
| // |u1 u2|i + |v0 v2|j + |u0 u1|k |
| // |v1 v2| |u0 u2| |v0 v1| |
| |
| auto* u0 = u->Index(0); |
| auto* u1 = u->Index(1); |
| auto* u2 = u->Index(2); |
| auto* v0 = v->Index(0); |
| auto* v1 = v->Index(1); |
| auto* v2 = v->Index(2); |
| |
| auto x = Dispatch_fa_f32_f16(Det2Func(source, elem_ty), u1, u2, v1, v2); |
| if (!x) { |
| return utils::Failure; |
| } |
| auto y = Dispatch_fa_f32_f16(Det2Func(source, elem_ty), v0, v2, u0, u2); |
| if (!y) { |
| return utils::Failure; |
| } |
| auto z = Dispatch_fa_f32_f16(Det2Func(source, elem_ty), u0, u1, v0, v1); |
| if (!z) { |
| return utils::Failure; |
| } |
| |
| return builder.constants.Composite( |
| ty, utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()}); |
| } |
| |
| ConstEval::Result ConstEval::degrees(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto transform = [&](const constant::Value* c0) { |
| auto create = [&](auto e) -> ConstEval::Result { |
| using NumberT = decltype(e); |
| using T = UnwrapNumber<NumberT>; |
| |
| auto pi = kPi<T>; |
| auto scale = Div(source, NumberT(180), NumberT(pi)); |
| if (!scale) { |
| AddNote("when calculating degrees", source); |
| return utils::Failure; |
| } |
| auto result = Mul(source, e, scale.Get()); |
| if (!result) { |
| AddNote("when calculating degrees", source); |
| return utils::Failure; |
| } |
| return CreateScalar(source, c0->Type(), result.Get()); |
| }; |
| return Dispatch_fa_f32_f16(create, c0); |
| }; |
| return TransformElements(builder, ty, transform, args[0]); |
| } |
| |
| ConstEval::Result ConstEval::determinant(const type::Type* ty, |
| utils::VectorRef<const constant::Value*> args, |
| const Source& source) { |
| auto calculate = [&]() -> ConstEval::Result { |
| auto* m = args[0]; |
| auto* mat_ty = m->Type()->As<type::Matrix>(); |
| auto me = [&](size_t r, size_t c) { return m->Index(c)->Index(r); }; |
| switch (mat_ty->rows()) { |
| case 2: |
| return Dispatch_fa_f32_f16(Det2Func(source, ty), // |
| me(0, 0), me(1, 0), // |
| me(0, 1), me(1, 1)); |
| |