tint/const-eval: Add flag to use runtime semantics
Add a flag to resolver::ConstEval to turn all overflow and range
errors into warnings, and return a valid (usually zero) value instead
of utils::Failure as defined by the WGSL spec for expressions
evaluated at runtime.
Change-Id: Icdce512306aabe717591134a1b4ba2d9c668f29c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118640
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn
index e49c128..dfce1a4 100644
--- a/src/tint/BUILD.gn
+++ b/src/tint/BUILD.gn
@@ -1442,6 +1442,7 @@
"resolver/const_eval_conversion_test.cc",
"resolver/const_eval_indexing_test.cc",
"resolver/const_eval_member_access_test.cc",
+ "resolver/const_eval_runtime_semantics_test.cc",
"resolver/const_eval_test.h",
"resolver/const_eval_unary_op_test.cc",
"resolver/control_block_validation_test.cc",
diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt
index 93c8f98..f4df7c9 100644
--- a/src/tint/CMakeLists.txt
+++ b/src/tint/CMakeLists.txt
@@ -936,6 +936,7 @@
resolver/const_eval_conversion_test.cc
resolver/const_eval_indexing_test.cc
resolver/const_eval_member_access_test.cc
+ resolver/const_eval_runtime_semantics_test.cc
resolver/const_eval_test.h
resolver/const_eval_unary_op_test.cc
resolver/control_block_validation_test.cc
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 0517629..09878c4 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -247,7 +247,8 @@
ConstEval::Result ScalarConvert(const constant::Scalar<T>* scalar,
ProgramBuilder& builder,
const type::Type* target_ty,
- const Source& source) {
+ const Source& source,
+ bool use_runtime_semantics) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == scalar->type) {
// If the types are identical, then no conversion is needed.
@@ -271,17 +272,35 @@
// --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure
- builder.Diagnostics().add_error(
- tint::diag::System::Resolver,
- OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
- return utils::Failure;
+ auto msg = OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty));
+ if (use_runtime_semantics) {
+ builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source);
+ switch (conv.Failure()) {
+ case ConversionFailure::kExceedsNegativeLimit:
+ return builder.create<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ case ConversionFailure::kExceedsPositiveLimit:
+ return builder.create<constant::Scalar<TO>>(target_ty, TO::Highest());
+ }
+ } else {
+ builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source);
+ return utils::Failure;
+ }
} else if constexpr (IsFloatingPoint<TO>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
- builder.Diagnostics().add_error(
- tint::diag::System::Resolver,
- OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty)), source);
- return utils::Failure;
+ auto msg = OverflowErrorMessage(scalar->value, builder.FriendlyName(target_ty));
+ if (use_runtime_semantics) {
+ builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source);
+ switch (conv.Failure()) {
+ case ConversionFailure::kExceedsNegativeLimit:
+ return builder.create<constant::Scalar<TO>>(target_ty, TO::Lowest());
+ case ConversionFailure::kExceedsPositiveLimit:
+ return builder.create<constant::Scalar<TO>>(target_ty, TO::Highest());
+ }
+ } else {
+ builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source);
+ return utils::Failure;
+ }
} else if constexpr (IsFloatingPoint<FROM>) {
// [floating-point -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
@@ -305,14 +324,18 @@
ConstEval::Result ConvertInternal(const constant::Value* c,
ProgramBuilder& builder,
const type::Type* target_ty,
- const Source& source);
+ const Source& source,
+ bool use_runtime_semantics);
+const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type);
ConstEval::Result SplatConvert(const constant::Splat* splat,
ProgramBuilder& builder,
const type::Type* target_ty,
- const Source& source) {
+ const Source& source,
+ bool use_runtime_semantics) {
// Convert the single splatted element type.
- auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source);
+ auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source,
+ use_runtime_semantics);
if (!conv_el) {
return utils::Failure;
}
@@ -325,7 +348,8 @@
ConstEval::Result CompositeConvert(const constant::Composite* composite,
ProgramBuilder& builder,
const type::Type* target_ty,
- const Source& source) {
+ const Source& source,
+ bool use_runtime_semantics) {
// Convert each of the composite element types.
utils::Vector<const constant::Value*, 4> conv_els;
conv_els.Reserve(composite->elements.Length());
@@ -344,7 +368,8 @@
}
for (auto* el : composite->elements) {
- auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source);
+ auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source,
+ use_runtime_semantics);
if (!conv_el) {
return utils::Failure;
}
@@ -359,33 +384,36 @@
ConstEval::Result ConvertInternal(const constant::Value* c,
ProgramBuilder& builder,
const type::Type* target_ty,
- const Source& source) {
+ const Source& source,
+ bool use_runtime_semantics) {
return Switch(
c,
[&](const constant::Scalar<tint::AFloat>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
[&](const constant::Scalar<tint::AInt>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
[&](const constant::Scalar<tint::u32>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
[&](const constant::Scalar<tint::i32>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
[&](const constant::Scalar<tint::f32>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
[&](const constant::Scalar<tint::f16>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
[&](const constant::Scalar<bool>* val) {
- return ScalarConvert(val, builder, target_ty, source);
+ return ScalarConvert(val, builder, target_ty, source, use_runtime_semantics);
},
- [&](const constant::Splat* val) { return SplatConvert(val, builder, target_ty, source); },
+ [&](const constant::Splat* val) {
+ return SplatConvert(val, builder, target_ty, source, use_runtime_semantics);
+ },
[&](const constant::Composite* val) {
- return CompositeConvert(val, builder, target_ty, source);
+ return CompositeConvert(val, builder, target_ty, source, use_runtime_semantics);
});
}
@@ -394,15 +422,21 @@
ConstEval::Result CreateScalar(ProgramBuilder& builder,
const Source& source,
const type::Type* t,
- T v) {
+ T v,
+ bool use_runtime_semantics) {
static_assert(IsNumber<T> || std::is_same_v<T, bool>, "T must be a Number or bool");
TINT_ASSERT(Resolver, t->is_scalar());
if constexpr (IsFloatingPoint<T>) {
if (!std::isfinite(v.value)) {
auto msg = OverflowErrorMessage(v, builder.FriendlyName(t));
- builder.Diagnostics().add_error(diag::System::Resolver, msg, source);
- return utils::Failure;
+ if (use_runtime_semantics) {
+ builder.Diagnostics().add_warning(diag::System::Resolver, msg, source);
+ return ZeroValue(builder, t);
+ } else {
+ builder.Diagnostics().add_error(diag::System::Resolver, msg, source);
+ return utils::Failure;
+ }
}
}
return builder.create<constant::Scalar<T>>(t, v);
@@ -448,7 +482,7 @@
},
[&](Default) -> const constant::Value* {
return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* {
- auto el = CreateScalar(builder, Source{}, type, zero);
+ auto el = CreateScalar(builder, Source{}, type, zero, false);
TINT_ASSERT(Resolver, el);
return el.Get();
});
@@ -543,7 +577,8 @@
}
} // namespace
-ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
+ConstEval::ConstEval(ProgramBuilder& b, bool use_runtime_semantics /* = false */)
+ : builder(b), use_runtime_semantics_(use_runtime_semantics) {}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Add(const Source& source, NumberT a, NumberT b) {
@@ -553,7 +588,11 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(a, "+", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return NumberT{0};
+ } else {
+ return utils::Failure;
+ }
}
} else {
using T = UnwrapNumber<NumberT>;
@@ -579,7 +618,11 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(a, "-", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return NumberT{0};
+ } else {
+ return utils::Failure;
+ }
}
} else {
using T = UnwrapNumber<NumberT>;
@@ -606,7 +649,11 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(a, "*", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return NumberT{0};
+ } else {
+ return utils::Failure;
+ }
}
} else {
auto mul_values = [](T lhs, T rhs) {
@@ -631,7 +678,11 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(a, "/", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return a;
+ } else {
+ return utils::Failure;
+ }
}
} else {
using T = UnwrapNumber<NumberT>;
@@ -640,14 +691,22 @@
if (rhs == 0) {
// For integers (as for floats), lhs / 0 is an error
AddError(OverflowErrorMessage(a, "/", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return a;
+ } else {
+ return utils::Failure;
+ }
}
if constexpr (std::is_signed_v<T>) {
// For signed integers, lhs / -1 where lhs is the
// most negative value is an error
if (rhs == -1 && lhs == std::numeric_limits<T>::min()) {
AddError(OverflowErrorMessage(a, "/", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return a;
+ } else {
+ return utils::Failure;
+ }
}
}
result = lhs / rhs;
@@ -663,7 +722,11 @@
result = r->value;
} else {
AddError(OverflowErrorMessage(a, "%", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return a;
+ } else {
+ return utils::Failure;
+ }
}
} else {
using T = UnwrapNumber<NumberT>;
@@ -672,14 +735,22 @@
if (rhs == 0) {
// lhs % 0 is an error
AddError(OverflowErrorMessage(a, "%", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return a;
+ } else {
+ return utils::Failure;
+ }
}
if constexpr (std::is_signed_v<T>) {
// For signed integers, lhs % -1 where lhs is the
// most negative value is an error
if (rhs == -1 && lhs == std::numeric_limits<T>::min()) {
AddError(OverflowErrorMessage(a, "%", b), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return a;
+ } else {
+ return utils::Failure;
+ }
}
}
result = lhs % rhs;
@@ -935,7 +1006,11 @@
utils::Result<NumberT> ConstEval::Sqrt(const Source& source, NumberT v) {
if (v < NumberT(0)) {
AddError("sqrt must be called with a value >= 0", source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return NumberT{0};
+ } else {
+ return utils::Failure;
+ }
}
return NumberT{std::sqrt(v)};
}
@@ -943,7 +1018,7 @@
auto ConstEval::SqrtFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto v) -> ConstEval::Result {
if (auto r = Sqrt(source, v)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -957,7 +1032,7 @@
auto ConstEval::ClampFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto e, auto low, auto high) -> ConstEval::Result {
if (auto r = Clamp(source, e, low, high)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -966,7 +1041,7 @@
auto ConstEval::AddFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Add(source, a1, a2)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -975,7 +1050,7 @@
auto ConstEval::SubFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Sub(source, a1, a2)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -984,7 +1059,7 @@
auto ConstEval::MulFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Mul(source, a1, a2)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -993,7 +1068,7 @@
auto ConstEval::DivFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Div(source, a1, a2)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1002,7 +1077,7 @@
auto ConstEval::ModFunc(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2) -> ConstEval::Result {
if (auto r = Mod(source, a1, a2)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1011,7 +1086,7 @@
auto ConstEval::Dot2Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2, auto b1, auto b2) -> ConstEval::Result {
if (auto r = Dot2(source, a1, a2, b1, b2)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1020,7 +1095,7 @@
auto ConstEval::Dot3Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ConstEval::Result {
if (auto r = Dot3(source, a1, a2, a3, b1, b2, b3)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1030,7 +1105,7 @@
return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
auto b4) -> ConstEval::Result {
if (auto r = Dot4(source, a1, a2, a3, a4, b1, b2, b3, b4)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1071,7 +1146,7 @@
if (vec_ty == nullptr) {
auto create = [&](auto e) {
using NumberT = decltype(e);
- return CreateScalar(builder, source, ty, NumberT{std::abs(e)});
+ return CreateScalar(builder, source, ty, NumberT{std::abs(e)}, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
}
@@ -1107,7 +1182,7 @@
auto ConstEval::Det2Func(const Source& source, const type::Type* elem_ty) {
return [=](auto a, auto b, auto c, auto d) -> ConstEval::Result {
if (auto r = Det2(source, a, b, c, d)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1117,7 +1192,7 @@
return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h,
auto i) -> ConstEval::Result {
if (auto r = Det3(source, a, b, c, d, e, f, g, h, i)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1127,7 +1202,7 @@
return [=](auto a, auto b, auto c, auto d, auto e, auto f, auto g, auto h, auto i, auto j,
auto k, auto l, auto m, auto n, auto o, auto p) -> ConstEval::Result {
if (auto r = Det4(source, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)) {
- return CreateScalar(builder, source, elem_ty, r.Get());
+ return CreateScalar(builder, source, elem_ty, r.Get(), use_runtime_semantics_);
}
return utils::Failure;
};
@@ -1138,27 +1213,33 @@
return Switch(
literal,
[&](const ast::BoolLiteralExpression* lit) {
- return CreateScalar(builder, source, ty, lit->value);
+ return CreateScalar(builder, source, ty, lit->value, use_runtime_semantics_);
},
[&](const ast::IntLiteralExpression* lit) -> ConstEval::Result {
switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone:
- return CreateScalar(builder, source, ty, AInt(lit->value));
+ return CreateScalar(builder, source, ty, AInt(lit->value),
+ use_runtime_semantics_);
case ast::IntLiteralExpression::Suffix::kI:
- return CreateScalar(builder, source, ty, i32(lit->value));
+ return CreateScalar(builder, source, ty, i32(lit->value),
+ use_runtime_semantics_);
case ast::IntLiteralExpression::Suffix::kU:
- return CreateScalar(builder, source, ty, u32(lit->value));
+ return CreateScalar(builder, source, ty, u32(lit->value),
+ use_runtime_semantics_);
}
return nullptr;
},
[&](const ast::FloatLiteralExpression* lit) -> ConstEval::Result {
switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone:
- return CreateScalar(builder, source, ty, AFloat(lit->value));
+ return CreateScalar(builder, source, ty, AFloat(lit->value),
+ use_runtime_semantics_);
case ast::FloatLiteralExpression::Suffix::kF:
- return CreateScalar(builder, source, ty, f32(lit->value));
+ return CreateScalar(builder, source, ty, f32(lit->value),
+ use_runtime_semantics_);
case ast::FloatLiteralExpression::Suffix::kH:
- return CreateScalar(builder, source, ty, f16(lit->value));
+ return CreateScalar(builder, source, ty, f16(lit->value),
+ use_runtime_semantics_);
}
return nullptr;
});
@@ -1277,7 +1358,8 @@
return builder.create<constant::Composite>(ty, args);
}
-ConstEval::Result ConstEval::Index(const sem::ValueExpression* obj_expr,
+ConstEval::Result ConstEval::Index(const type::Type* ty,
+ const sem::ValueExpression* obj_expr,
const sem::ValueExpression* idx_expr) {
auto idx_val = idx_expr->ConstantValue();
if (!idx_val) {
@@ -1295,7 +1377,11 @@
}
AddError("index " + std::to_string(idx) + " out of bounds" + range,
idx_expr->Declaration()->source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, ty);
+ } else {
+ return utils::Failure;
+ }
}
auto obj_val = obj_expr->ConstantValue();
@@ -1340,15 +1426,15 @@
el_ty,
[&](const type::U32*) { //
auto r = utils::Bitcast<u32>(e);
- return CreateScalar(builder, source, el_ty, r);
+ return CreateScalar(builder, source, el_ty, r, use_runtime_semantics_);
},
[&](const type::I32*) { //
auto r = utils::Bitcast<i32>(e);
- return CreateScalar(builder, source, el_ty, r);
+ return CreateScalar(builder, source, el_ty, r, use_runtime_semantics_);
},
[&](const type::F32*) { //
auto r = utils::Bitcast<f32>(e);
- return CreateScalar(builder, source, el_ty, r);
+ return CreateScalar(builder, source, el_ty, r, use_runtime_semantics_);
});
};
return Dispatch_fiu32(create, c0);
@@ -1361,7 +1447,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c) {
auto create = [&](auto i) {
- return CreateScalar(builder, source, c->Type(), decltype(i)(~i.value));
+ return CreateScalar(builder, source, c->Type(), decltype(i)(~i.value),
+ use_runtime_semantics_);
};
return Dispatch_ia_iu32(create, c);
};
@@ -1383,9 +1470,11 @@
if (v != std::numeric_limits<T>::min()) {
v = -v;
}
- return CreateScalar(builder, source, c->Type(), decltype(i)(v));
+ return CreateScalar(builder, source, c->Type(), decltype(i)(v),
+ use_runtime_semantics_);
} else {
- return CreateScalar(builder, source, c->Type(), decltype(i)(-i.value));
+ return CreateScalar(builder, source, c->Type(), decltype(i)(-i.value),
+ use_runtime_semantics_);
}
};
return Dispatch_fia_fi32_f16(create, c);
@@ -1398,7 +1487,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c) {
auto create = [&](auto i) {
- return CreateScalar(builder, source, c->Type(), decltype(i)(!i));
+ return CreateScalar(builder, source, c->Type(), decltype(i)(!i),
+ use_runtime_semantics_);
};
return Dispatch_bool(create, c);
};
@@ -1617,7 +1707,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -1630,7 +1721,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -1643,7 +1735,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1656,7 +1749,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1669,7 +1763,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1682,7 +1777,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -1695,7 +1791,8 @@
const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is true, so we could
// technically only return the value of the rhs.
- return CreateScalar(builder, source, ty, args[0]->ValueAs<bool>() && args[1]->ValueAs<bool>());
+ return CreateScalar(builder, source, ty, args[0]->ValueAs<bool>() && args[1]->ValueAs<bool>(),
+ use_runtime_semantics_);
}
ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty,
@@ -1703,7 +1800,7 @@
const Source& source) {
// Note: Due to short-circuiting, this function is only called if lhs is false, so we could
// technically only return the value of the rhs.
- return CreateScalar(builder, source, ty, args[1]->ValueAs<bool>());
+ return CreateScalar(builder, source, ty, args[1]->ValueAs<bool>(), use_runtime_semantics_);
}
ConstEval::Result ConstEval::OpAnd(const type::Type* ty,
@@ -1718,7 +1815,8 @@
} else { // integral
result = i & j;
}
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result,
+ use_runtime_semantics_);
};
return Dispatch_ia_iu32_bool(create, c0, c1);
};
@@ -1738,7 +1836,8 @@
} else { // integral
result = i | j;
}
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), result,
+ use_runtime_semantics_);
};
return Dispatch_ia_iu32_bool(create, c0, c1);
};
@@ -1752,7 +1851,7 @@
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) -> ConstEval::Result {
return CreateScalar(builder, source, type::Type::DeepestElementOf(ty),
- decltype(i){i ^ j});
+ decltype(i){i ^ j}, use_runtime_semantics_);
};
return Dispatch_ia_iu32(create, c0, c1);
};
@@ -1782,18 +1881,22 @@
UT mask = ~UT{0} << (bit_width - must_match_msb);
if ((e1u & mask) != 0 && (e1u & mask) != mask) {
AddError("shift left operation results in sign change", source);
- return utils::Failure;
+ if (!use_runtime_semantics_) {
+ return utils::Failure;
+ }
}
} else {
// If shift value >= bit_width, then any non-zero value would overflow
if (e1 != 0) {
AddError(OverflowErrorMessage(e1, "<<", e2), source);
- return utils::Failure;
+ if (!use_runtime_semantics_) {
+ return utils::Failure;
+ }
}
// It's UB in C++ to shift by greater or equal to the bit width (even if the lhs
// is 0), so we make sure to avoid this by setting the shift value to 0.
- e2 = 0;
+ e2u = 0;
}
} else {
if (static_cast<size_t>(e2) >= bit_width) {
@@ -1804,7 +1907,11 @@
"shift left value must be less than the bit width of the lhs, which is " +
std::to_string(bit_width),
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ e2u = e2u % bit_width;
+ } else {
+ return utils::Failure;
+ }
}
if constexpr (std::is_signed_v<T>) {
@@ -1814,7 +1921,9 @@
UT mask = ~UT{0} << (bit_width - must_match_msb);
if ((e1u & mask) != 0 && (e1u & mask) != mask) {
AddError("shift left operation results in sign change", source);
- return utils::Failure;
+ if (!use_runtime_semantics_) {
+ return utils::Failure;
+ }
}
} else {
// If T is an unsigned integer type, and any of the e2 most significant bits of
@@ -1824,15 +1933,18 @@
UT mask = ~UT{0} << (bit_width - must_be_zero_msb);
if ((e1u & mask) != 0) {
AddError(OverflowErrorMessage(e1, "<<", e2), source);
- return utils::Failure;
+ if (!use_runtime_semantics_) {
+ return utils::Failure;
+ }
}
}
}
}
// Avoid UB by left shifting as unsigned value
- auto result = static_cast<T>(static_cast<UT>(e1) << e2);
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result});
+ auto result = static_cast<T>(static_cast<UT>(e1) << e2u);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result},
+ use_runtime_semantics_);
};
return Dispatch_ia_iu32(create, c0, c1);
};
@@ -1855,8 +1967,8 @@
using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>;
const size_t bit_width = BitWidth<NumberT>;
- const UT e1u = static_cast<UT>(e1);
- const UT e2u = static_cast<UT>(e2);
+ UT e1u = static_cast<UT>(e1);
+ UT e2u = static_cast<UT>(e2);
auto signed_shift_right = [&] {
// In C++, right shift of a signed negative number is implementation-defined.
@@ -1887,16 +1999,21 @@
"shift right value must be less than the bit width of the lhs, which is " +
std::to_string(bit_width),
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ e2u = e2u % bit_width;
+ } else {
+ return utils::Failure;
+ }
}
if constexpr (std::is_signed_v<T>) {
result = signed_shift_right();
} else {
- result = e1 >> e2;
+ result = e1 >> e2u;
}
}
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result});
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), NumberT{result},
+ use_runtime_semantics_);
};
return Dispatch_ia_iu32(create, c0, c1);
};
@@ -1928,7 +2045,7 @@
} else {
result = NumberT{std::abs(e)};
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0);
};
@@ -1944,9 +2061,14 @@
if (i < NumberT(-1.0) || i > NumberT(1.0)) {
AddError("acos must be called with a value in the range [-1 .. 1] (inclusive)",
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), NumberT(std::acos(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::acos(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -1961,9 +2083,14 @@
using NumberT = decltype(i);
if (i < NumberT(1.0)) {
AddError("acosh must be called with a value >= 1.0", source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), NumberT(std::acosh(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::acosh(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -1974,13 +2101,13 @@
ConstEval::Result ConstEval::all(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- return CreateScalar(builder, source, ty, !args[0]->AnyZero());
+ return CreateScalar(builder, source, ty, !args[0]->AnyZero(), use_runtime_semantics_);
}
ConstEval::Result ConstEval::any(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source& source) {
- return CreateScalar(builder, source, ty, !args[0]->AllZero());
+ return CreateScalar(builder, source, ty, !args[0]->AllZero(), use_runtime_semantics_);
}
ConstEval::Result ConstEval::asin(const type::Type* ty,
@@ -1992,9 +2119,14 @@
if (i < NumberT(-1.0) || i > NumberT(1.0)) {
AddError("asin must be called with a value in the range [-1 .. 1] (inclusive)",
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), NumberT(std::asin(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::asin(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2006,7 +2138,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) {
- return CreateScalar(builder, source, c0->Type(), decltype(i)(std::asinh(i.value)));
+ return CreateScalar(builder, source, c0->Type(), decltype(i)(std::asinh(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2019,7 +2152,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) {
- return CreateScalar(builder, source, c0->Type(), decltype(i)(std::atan(i.value)));
+ return CreateScalar(builder, source, c0->Type(), decltype(i)(std::atan(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2035,9 +2169,14 @@
if (i <= NumberT(-1.0) || i >= NumberT(1.0)) {
AddError("atanh must be called with a value in the range (-1 .. 1) (exclusive)",
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), NumberT(std::atanh(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::atanh(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2051,7 +2190,7 @@
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto i, auto j) {
return CreateScalar(builder, source, c0->Type(),
- decltype(i)(std::atan2(i.value, j.value)));
+ decltype(i)(std::atan2(i.value, j.value)), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
@@ -2063,7 +2202,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto e) {
- return CreateScalar(builder, source, c0->Type(), decltype(e)(std::ceil(e)));
+ return CreateScalar(builder, source, c0->Type(), decltype(e)(std::ceil(e)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2086,7 +2226,8 @@
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i);
- return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2099,7 +2240,8 @@
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i);
- return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2114,7 +2256,8 @@
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
auto count = CountLeadingBits(T{e}, T{0});
- return CreateScalar(builder, source, c0->Type(), NumberT(count));
+ return CreateScalar(builder, source, c0->Type(), NumberT(count),
+ use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -2138,7 +2281,8 @@
}
}
- return CreateScalar(builder, source, c0->Type(), NumberT(count));
+ return CreateScalar(builder, source, c0->Type(), NumberT(count),
+ use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -2153,7 +2297,8 @@
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
auto count = CountTrailingBits(T{e}, T{0});
- return CreateScalar(builder, source, c0->Type(), NumberT(count));
+ return CreateScalar(builder, source, c0->Type(), NumberT(count),
+ use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -2222,7 +2367,7 @@
AddNote("when calculating degrees", source);
return utils::Failure;
}
- return CreateScalar(builder, source, c0->Type(), result.Get());
+ return CreateScalar(builder, source, c0->Type(), result.Get(), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2304,9 +2449,13 @@
auto val = NumberT(std::exp(e0));
if (!std::isfinite(val.value)) {
AddError(OverflowExpErrorMessage("e", e0), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), val);
+ return CreateScalar(builder, source, c0->Type(), val, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2322,9 +2471,13 @@
auto val = NumberT(std::exp2(e0));
if (!std::isfinite(val.value)) {
AddError(OverflowExpErrorMessage("2", e0), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), val);
+ return CreateScalar(builder, source, c0->Type(), val, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2354,7 +2507,12 @@
if (o > w || c > w || (o + c) > w) {
AddError("'offset + 'count' must be less than or equal to the bit width of 'e'",
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ o = std::min(o, w);
+ c = std::min(c, w - o);
+ } else {
+ return utils::Failure;
+ }
}
NumberT result;
@@ -2379,7 +2537,7 @@
result = NumberT{r};
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -2442,7 +2600,7 @@
}
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -2468,7 +2626,7 @@
result = NumberT(pos);
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -2480,7 +2638,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto e) {
- return CreateScalar(builder, source, c0->Type(), decltype(e)(std::floor(e)));
+ return CreateScalar(builder, source, c0->Type(), decltype(e)(std::floor(e)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2507,7 +2666,7 @@
if (!val) {
return err_msg();
}
- return CreateScalar(builder, source, c1->Type(), val.Get());
+ return CreateScalar(builder, source, c1->Type(), val.Get(), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c1, c2, c3);
};
@@ -2521,7 +2680,7 @@
auto create = [&](auto e) -> ConstEval::Result {
using NumberT = decltype(e);
auto r = e - std::floor(e);
- return CreateScalar(builder, source, c1->Type(), NumberT{r});
+ return CreateScalar(builder, source, c1->Type(), NumberT{r}, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c1);
};
@@ -2545,21 +2704,26 @@
s->Type(),
[&](const type::F32*) {
return FractExp{
- CreateScalar(builder, source, builder.create<type::F32>(), f32(fract)),
- CreateScalar(builder, source, builder.create<type::I32>(), i32(exp)),
+ CreateScalar(builder, source, builder.create<type::F32>(), f32(fract),
+ use_runtime_semantics_),
+ CreateScalar(builder, source, builder.create<type::I32>(), i32(exp),
+ use_runtime_semantics_),
};
},
[&](const type::F16*) {
return FractExp{
- CreateScalar(builder, source, builder.create<type::F16>(), f16(fract)),
- CreateScalar(builder, source, builder.create<type::I32>(), i32(exp)),
+ CreateScalar(builder, source, builder.create<type::F16>(), f16(fract),
+ use_runtime_semantics_),
+ CreateScalar(builder, source, builder.create<type::I32>(), i32(exp),
+ use_runtime_semantics_),
};
},
[&](const type::AbstractFloat*) {
return FractExp{
CreateScalar(builder, source, builder.create<type::AbstractFloat>(),
- AFloat(fract)),
- CreateScalar(builder, source, builder.create<type::AbstractInt>(), AInt(exp)),
+ AFloat(fract), use_runtime_semantics_),
+ CreateScalar(builder, source, builder.create<type::AbstractInt>(), AInt(exp),
+ use_runtime_semantics_),
};
},
[&](Default) {
@@ -2624,7 +2788,12 @@
if (o > w || c > w || (o + c) > w) {
AddError("'offset + 'count' must be less than or equal to the bit width of 'e'",
source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ o = std::min(o, w);
+ c = std::min(c, w - o);
+ } else {
+ return utils::Failure;
+ }
}
NumberT result;
@@ -2645,7 +2814,7 @@
result = NumberT{r};
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_iu32(create, c0, c1);
};
@@ -2661,7 +2830,11 @@
if (e <= NumberT(0)) {
AddError("inverseSqrt must be called with a value > 0", source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
auto err = [&] {
@@ -2678,7 +2851,7 @@
return err();
}
- return CreateScalar(builder, source, c0->Type(), div.Get());
+ return CreateScalar(builder, source, c0->Type(), div.Get(), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2714,13 +2887,17 @@
if (e2 > bias + 1) {
AddError("e2 must be less than or equal to " + std::to_string(bias + 1), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c1->Type());
+ } else {
+ return utils::Failure;
+ }
}
auto target_ty = type::Type::DeepestElementOf(ty);
auto r = std::ldexp(e1, static_cast<int>(e2));
- return CreateScalar(builder, source, target_ty, E1Type{r});
+ return CreateScalar(builder, source, target_ty, E1Type{r}, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c1);
};
@@ -2746,9 +2923,14 @@
using NumberT = decltype(v);
if (v <= NumberT(0)) {
AddError("log must be called with a value > 0", source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), NumberT(std::log(v)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::log(v)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2763,9 +2945,14 @@
using NumberT = decltype(v);
if (v <= NumberT(0)) {
AddError("log2 must be called with a value > 0", source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), NumberT(std::log2(v)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::log2(v)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -2777,7 +2964,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto e0, auto e1) {
- return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1)));
+ return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1)),
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -2789,7 +2977,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto e0, auto e1) {
- return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1)));
+ return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1)),
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
@@ -2828,7 +3017,7 @@
if (!r) {
return utils::Failure;
}
- return CreateScalar(builder, source, c0->Type(), r.Get());
+ return CreateScalar(builder, source, c0->Type(), r.Get(), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
@@ -2845,13 +3034,14 @@
auto transform_fract = [&](const constant::Value* c) {
auto create = [&](auto e) {
return CreateScalar(builder, source, c->Type(),
- decltype(e)(e.value - std::trunc(e.value)));
+ decltype(e)(e.value - std::trunc(e.value)), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c);
};
auto transform_whole = [&](const constant::Value* c) {
auto create = [&](auto e) {
- return CreateScalar(builder, source, c->Type(), decltype(e)(std::trunc(e.value)));
+ return CreateScalar(builder, source, c->Type(), decltype(e)(std::trunc(e.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c);
};
@@ -2885,7 +3075,11 @@
auto* v = len.Get();
if (v->AllZero()) {
AddError("zero length vector can not be normalized", source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, ty);
+ } else {
+ return utils::Failure;
+ }
}
return OpDivide(ty, utils::Vector{args[0], v}, source);
}
@@ -2897,7 +3091,11 @@
auto conv = CheckedConvert<f16>(val);
if (!conv) {
AddError(OverflowErrorMessage(val, "f16"), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return 0;
+ } else {
+ return utils::Failure;
+ }
}
uint16_t v = conv.Get().BitsRepresentation();
return utils::Result<uint32_t>{v};
@@ -2915,7 +3113,7 @@
}
u32 ret = u32((e0.Get() & 0x0000'ffff) | (e1.Get() << 16));
- return CreateScalar(builder, source, ty, ret);
+ return CreateScalar(builder, source, ty, ret, use_runtime_semantics_);
}
ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty,
@@ -2932,7 +3130,7 @@
auto e1 = calc(e->Index(1)->ValueAs<f32>());
u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16));
- return CreateScalar(builder, source, ty, ret);
+ return CreateScalar(builder, source, ty, ret, use_runtime_semantics_);
}
ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty,
@@ -2948,7 +3146,7 @@
auto e1 = calc(e->Index(1)->ValueAs<f32>());
u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16));
- return CreateScalar(builder, source, ty, ret);
+ return CreateScalar(builder, source, ty, ret, use_runtime_semantics_);
}
ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty,
@@ -2968,7 +3166,7 @@
uint32_t mask = 0x0000'00ff;
u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24));
- return CreateScalar(builder, source, ty, ret);
+ return CreateScalar(builder, source, ty, ret, use_runtime_semantics_);
}
ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty,
@@ -2987,7 +3185,7 @@
uint32_t mask = 0x0000'00ff;
u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24));
- return CreateScalar(builder, source, ty, ret);
+ return CreateScalar(builder, source, ty, ret, use_runtime_semantics_);
}
ConstEval::Result ConstEval::pow(const type::Type* ty,
@@ -2998,9 +3196,13 @@
auto r = CheckedPow(e1, e2);
if (!r) {
AddError(OverflowErrorMessage(e1, "^", e2), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c0->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c0->Type(), *r);
+ return CreateScalar(builder, source, c0->Type(), *r, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
@@ -3026,7 +3228,7 @@
AddNote("when calculating radians", source);
return utils::Failure;
}
- return CreateScalar(builder, source, c0->Type(), result.Get());
+ return CreateScalar(builder, source, c0->Type(), result.Get(), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3053,7 +3255,8 @@
// 2 * dot(e2, e1)
auto mul2 = [&](auto v) -> ConstEval::Result {
using NumberT = decltype(v);
- return CreateScalar(builder, source, el_ty, NumberT{NumberT{2} * v});
+ return CreateScalar(builder, source, el_ty, NumberT{NumberT{2} * v},
+ use_runtime_semantics_);
};
auto dot_e2_e1_2 = Dispatch_fa_f32_f16(mul2, dot_e2_e1.Get());
if (!dot_e2_e1_2) {
@@ -3105,7 +3308,7 @@
if (!r) {
return utils::Failure;
}
- return CreateScalar(builder, source, el_ty, r.Get());
+ return CreateScalar(builder, source, el_ty, r.Get(), use_runtime_semantics_);
};
auto compute_e2_scale = [&](auto e3, auto dot_e2_e1, auto k) -> ConstEval::Result {
@@ -3122,7 +3325,7 @@
if (!r) {
return utils::Failure;
}
- return CreateScalar(builder, source, el_ty, r.Get());
+ return CreateScalar(builder, source, el_ty, r.Get(), use_runtime_semantics_);
};
auto calculate = [&]() -> ConstEval::Result {
@@ -3194,7 +3397,7 @@
}
}
- return CreateScalar(builder, source, c0->Type(), NumberT{r});
+ return CreateScalar(builder, source, c0->Type(), NumberT{r}, use_runtime_semantics_);
};
return Dispatch_iu32(create, c0);
};
@@ -3230,7 +3433,7 @@
} else {
result = NumberT(std::round(e.value));
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3244,7 +3447,8 @@
auto create = [&](auto e) {
using NumberT = decltype(e);
return CreateScalar(builder, source, c0->Type(),
- NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0))));
+ NumberT(std::min(std::max(e, NumberT(0.0)), NumberT(1.0))),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3257,7 +3461,8 @@
auto cond = args[2]->ValueAs<bool>();
auto transform = [&](const constant::Value* c0, const constant::Value* c1) {
auto create = [&](auto f, auto t) -> ConstEval::Result {
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -3272,7 +3477,8 @@
auto create = [&](auto f, auto t) -> ConstEval::Result {
// Get corresponding bool value at the current vector value index
auto cond = args[2]->Index(index)->ValueAs<bool>();
- return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f);
+ return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f,
+ use_runtime_semantics_);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
@@ -3295,7 +3501,7 @@
} else {
result = zero;
}
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_fia_fi32_f16(create, c0);
};
@@ -3308,7 +3514,8 @@
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i);
- return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3321,7 +3528,8 @@
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i);
- return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3372,7 +3580,7 @@
if (!result) {
return err();
}
- return CreateScalar(builder, source, c0->Type(), result.Get());
+ return CreateScalar(builder, source, c0->Type(), result.Get(), use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0, c1, c2);
};
@@ -3386,7 +3594,7 @@
auto create = [&](auto edge, auto x) -> ConstEval::Result {
using NumberT = decltype(edge);
NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0);
- return CreateScalar(builder, source, c0->Type(), result);
+ return CreateScalar(builder, source, c0->Type(), result, use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
@@ -3409,7 +3617,8 @@
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i);
- return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3422,7 +3631,8 @@
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) -> ConstEval::Result {
using NumberT = decltype(i);
- return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value)));
+ return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3455,7 +3665,8 @@
const Source& source) {
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto i) {
- return CreateScalar(builder, source, c0->Type(), decltype(i)(std::trunc(i.value)));
+ return CreateScalar(builder, source, c0->Type(), decltype(i)(std::trunc(i.value)),
+ use_runtime_semantics_);
};
return Dispatch_fa_f32_f16(create, c0);
};
@@ -3475,9 +3686,13 @@
auto val = CheckedConvert<f32>(in);
if (!val) {
AddError(OverflowErrorMessage(in, "f32"), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ val = f32(0.f);
+ } else {
+ return utils::Failure;
+ }
}
- auto el = CreateScalar(builder, source, inner_ty, val.Get());
+ auto el = CreateScalar(builder, source, inner_ty, val.Get(), use_runtime_semantics_);
if (!el) {
return el;
}
@@ -3497,7 +3712,7 @@
for (size_t i = 0; i < 2; ++i) {
auto val = f32(
std::max(static_cast<float>(int16_t((e >> (16 * i)) & 0x0000'ffff)) / 32767.f, -1.f));
- auto el = CreateScalar(builder, source, inner_ty, val);
+ auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_);
if (!el) {
return el;
}
@@ -3516,7 +3731,7 @@
els.Reserve(2);
for (size_t i = 0; i < 2; ++i) {
auto val = f32(static_cast<float>(uint16_t((e >> (16 * i)) & 0x0000'ffff)) / 65535.f);
- auto el = CreateScalar(builder, source, inner_ty, val);
+ auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_);
if (!el) {
return el;
}
@@ -3536,7 +3751,7 @@
for (size_t i = 0; i < 4; ++i) {
auto val =
f32(std::max(static_cast<float>(int8_t((e >> (8 * i)) & 0x0000'00ff)) / 127.f, -1.f));
- auto el = CreateScalar(builder, source, inner_ty, val);
+ auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_);
if (!el) {
return el;
}
@@ -3555,7 +3770,7 @@
els.Reserve(4);
for (size_t i = 0; i < 4; ++i) {
auto val = f32(static_cast<float>(uint8_t((e >> (8 * i)) & 0x0000'00ff)) / 255.f);
- auto el = CreateScalar(builder, source, inner_ty, val);
+ auto el = CreateScalar(builder, source, inner_ty, val, use_runtime_semantics_);
if (!el) {
return el;
}
@@ -3572,9 +3787,13 @@
auto conv = CheckedConvert<f32>(f16(value));
if (!conv) {
AddError(OverflowErrorMessage(value, "f16"), source);
- return utils::Failure;
+ if (use_runtime_semantics_) {
+ return ZeroValue(builder, c->Type());
+ } else {
+ return utils::Failure;
+ }
}
- return CreateScalar(builder, source, c->Type(), conv.Get());
+ return CreateScalar(builder, source, c->Type(), conv.Get(), use_runtime_semantics_);
};
return TransformElements(builder, ty, transform, args[0]);
}
@@ -3585,11 +3804,15 @@
if (value->Type() == target_ty) {
return value;
}
- return ConvertInternal(value, builder, target_ty, source);
+ return ConvertInternal(value, builder, target_ty, source, use_runtime_semantics_);
}
void ConstEval::AddError(const std::string& msg, const Source& source) const {
- builder.Diagnostics().add_error(diag::System::Resolver, msg, source);
+ if (use_runtime_semantics_) {
+ builder.Diagnostics().add_warning(diag::System::Resolver, msg, source);
+ } else {
+ builder.Diagnostics().add_error(diag::System::Resolver, msg, source);
+ }
}
void ConstEval::AddWarning(const std::string& msg, const Source& source) const {
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index 6c26511..df91cc0 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -68,7 +68,9 @@
/// Constructor
/// @param b the program builder
- explicit ConstEval(ProgramBuilder& b);
+ /// @param use_runtime_semantics if `true`, use the behavior defined for runtime evaluation, and
+ /// emit overflow and range errors as warnings instead of errors
+ explicit ConstEval(ProgramBuilder& b, bool use_runtime_semantics = false);
////////////////////////////////////////////////////////////////////////////////////////////////
// Constant value evaluation methods, to be called directly from Resolver
@@ -87,10 +89,13 @@
/// be calculated
Result Bitcast(const type::Type* ty, const constant::Value* value, const Source& source);
+ /// @param ty the target type
/// @param obj the object being indexed
/// @param idx the index expression
/// @return the result of the index, or null if the value cannot be calculated
- Result Index(const sem::ValueExpression* obj, const sem::ValueExpression* idx);
+ Result Index(const type::Type* ty,
+ const sem::ValueExpression* obj,
+ const sem::ValueExpression* idx);
/// @param ty the result type
/// @param lit the literal AST node
@@ -1404,6 +1409,7 @@
const constant::Value* v2);
ProgramBuilder& builder;
+ bool use_runtime_semantics_ = false;
};
} // namespace tint::resolver
diff --git a/src/tint/resolver/const_eval_runtime_semantics_test.cc b/src/tint/resolver/const_eval_runtime_semantics_test.cc
new file mode 100644
index 0000000..9de1866
--- /dev/null
+++ b/src/tint/resolver/const_eval_runtime_semantics_test.cc
@@ -0,0 +1,589 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/tint/resolver/const_eval_test.h"
+
+#include "src/tint/constant/scalar.h"
+
+using namespace tint::number_suffixes; // NOLINT
+
+namespace tint::resolver {
+namespace {
+
+class ResolverConstEvalRuntimeSemanticsTest : public ResolverConstEvalTest {
+ protected:
+ /// Default constructor.
+ ResolverConstEvalRuntimeSemanticsTest()
+ : const_eval(ConstEval(*this, /* use_runtime_semantics */ true)) {}
+
+ /// The ConstEval object used during testing (has runtime semantics enabled).
+ ConstEval const_eval;
+
+ /// @returns the contents of the diagnostics list as a string
+ std::string error() {
+ diag::Formatter::Style style{};
+ style.print_newline_at_end = false;
+ diag::Formatter formatter{style};
+ return formatter.format(Diagnostics());
+ }
+
+ /// Helper to make a scalar constant::Value from a value.
+ template <typename T>
+ const constant::Value* Scalar(T value) {
+ if constexpr (IsAbstract<T>) {
+ if constexpr (IsFloatingPoint<T>) {
+ return create<constant::Scalar<AFloat>>(create<type::AbstractFloat>(), value);
+ } else if constexpr (IsIntegral<T>) {
+ return create<constant::Scalar<AInt>>(create<type::AbstractInt>(), value);
+ }
+ } else if constexpr (IsFloatingPoint<T>) {
+ return create<constant::Scalar<f32>>(create<type::F32>(), value);
+ } else if constexpr (IsSignedIntegral<T>) {
+ return create<constant::Scalar<i32>>(create<type::I32>(), value);
+ } else if constexpr (IsUnsignedIntegral<T>) {
+ return create<constant::Scalar<u32>>(create<type::U32>(), value);
+ }
+ }
+};
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AInt_Overflow) {
+ auto* a = Scalar(AInt::Highest());
+ auto* b = Scalar(AInt(1));
+ auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
+ EXPECT_EQ(error(),
+ R"(warning: '9223372036854775807 + 1' cannot be represented as 'abstract-int')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AFloat_Overflow) {
+ auto* a = Scalar(AFloat::Highest());
+ auto* b = Scalar(AFloat::Highest());
+ auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
+ EXPECT_EQ(
+ error(),
+ R"(warning: '1.7976931348623157081e+308 + 1.7976931348623157081e+308' cannot be represented as 'abstract-float')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_F32_Overflow) {
+ auto* a = Scalar(f32::Highest());
+ auto* b = Scalar(f32::Highest());
+ auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(
+ error(),
+ R"(warning: '3.4028234663852885981e+38 + 3.4028234663852885981e+38' cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AInt_Overflow) {
+ auto* a = Scalar(AInt::Lowest());
+ auto* b = Scalar(AInt(1));
+ auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
+ EXPECT_EQ(error(),
+ R"(warning: '-9223372036854775808 - 1' cannot be represented as 'abstract-int')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AFloat_Overflow) {
+ auto* a = Scalar(AFloat::Lowest());
+ auto* b = Scalar(AFloat::Highest());
+ auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
+ EXPECT_EQ(
+ error(),
+ R"(warning: '-1.7976931348623157081e+308 - 1.7976931348623157081e+308' cannot be represented as 'abstract-float')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_F32_Overflow) {
+ auto* a = Scalar(f32::Lowest());
+ auto* b = Scalar(f32::Highest());
+ auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(
+ error(),
+ R"(warning: '-3.4028234663852885981e+38 - 3.4028234663852885981e+38' cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AInt_Overflow) {
+ auto* a = Scalar(AInt::Highest());
+ auto* b = Scalar(AInt(2));
+ auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AInt>(), 0);
+ EXPECT_EQ(error(),
+ R"(warning: '9223372036854775807 * 2' cannot be represented as 'abstract-int')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AFloat_Overflow) {
+ auto* a = Scalar(AFloat::Highest());
+ auto* b = Scalar(AFloat::Highest());
+ auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 0.f);
+ EXPECT_EQ(
+ error(),
+ R"(warning: '1.7976931348623157081e+308 * 1.7976931348623157081e+308' cannot be represented as 'abstract-float')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_F32_Overflow) {
+ auto* a = Scalar(f32::Highest());
+ auto* b = Scalar(f32::Highest());
+ auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(
+ error(),
+ R"(warning: '3.4028234663852885981e+38 * 3.4028234663852885981e+38' cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AInt_ZeroDenominator) {
+ auto* a = Scalar(AInt(42));
+ auto* b = Scalar(AInt(0));
+ auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AInt>(), 42);
+ EXPECT_EQ(error(), R"(warning: '42 / 0' cannot be represented as 'abstract-int')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_ZeroDenominator) {
+ auto* a = Scalar(i32(42));
+ auto* b = Scalar(i32(0));
+ auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), 42);
+ EXPECT_EQ(error(), R"(warning: '42 / 0' cannot be represented as 'i32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_U32_ZeroDenominator) {
+ auto* a = Scalar(u32(42));
+ auto* b = Scalar(u32(0));
+ auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 42);
+ EXPECT_EQ(error(), R"(warning: '42 / 0' cannot be represented as 'u32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AFloat_ZeroDenominator) {
+ auto* a = Scalar(AFloat(42));
+ auto* b = Scalar(AFloat(0));
+ auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 42.f);
+ EXPECT_EQ(error(), R"(warning: '42 / 0' cannot be represented as 'abstract-float')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_F32_ZeroDenominator) {
+ auto* a = Scalar(f32(42));
+ auto* b = Scalar(f32(0));
+ auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 42.f);
+ EXPECT_EQ(error(), R"(warning: '42 / 0' cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_MostNegativeByMinInt) {
+ auto* a = Scalar(i32::Lowest());
+ auto* b = Scalar(i32(-1));
+ auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), i32::Lowest());
+ EXPECT_EQ(error(), R"(warning: '-2147483648 / -1' cannot be represented as 'i32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AInt_ZeroDenominator) {
+ auto* a = Scalar(AInt(42));
+ auto* b = Scalar(AInt(0));
+ auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AInt>(), 42);
+ EXPECT_EQ(error(), R"(warning: '42 % 0' cannot be represented as 'abstract-int')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_ZeroDenominator) {
+ auto* a = Scalar(i32(42));
+ auto* b = Scalar(i32(0));
+ auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), 42);
+ EXPECT_EQ(error(), R"(warning: '42 % 0' cannot be represented as 'i32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_U32_ZeroDenominator) {
+ auto* a = Scalar(u32(42));
+ auto* b = Scalar(u32(0));
+ auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 42);
+ EXPECT_EQ(error(), R"(warning: '42 % 0' cannot be represented as 'u32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AFloat_ZeroDenominator) {
+ auto* a = Scalar(AFloat(42));
+ auto* b = Scalar(AFloat(0));
+ auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AFloat>(), 42.f);
+ EXPECT_EQ(error(), R"(warning: '42 % 0' cannot be represented as 'abstract-float')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_F32_ZeroDenominator) {
+ auto* a = Scalar(f32(42));
+ auto* b = Scalar(f32(0));
+ auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 42.f);
+ EXPECT_EQ(error(), R"(warning: '42 % 0' cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_MostNegativeByMinInt) {
+ auto* a = Scalar(i32::Lowest());
+ auto* b = Scalar(i32(-1));
+ auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), i32::Lowest());
+ EXPECT_EQ(error(), R"(warning: '-2147483648 % -1' cannot be represented as 'i32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_AInt_SignChange) {
+ auto* a = Scalar(AInt(0x0FFFFFFFFFFFFFFFll));
+ auto* b = Scalar(u32(9));
+ auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<AInt>(), static_cast<AInt>(0x0FFFFFFFFFFFFFFFull << 9));
+ EXPECT_EQ(error(), R"(warning: shift left operation results in sign change)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_SignChange) {
+ auto* a = Scalar(i32(0x0FFFFFFF));
+ auto* b = Scalar(u32(9));
+ auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), static_cast<i32>(0x0FFFFFFFu << 9));
+ EXPECT_EQ(error(), R"(warning: shift left operation results in sign change)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_MoreThanBitWidth) {
+ auto* a = Scalar(i32(0x1));
+ auto* b = Scalar(u32(33));
+ auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), 2);
+ EXPECT_EQ(
+ error(),
+ R"(warning: shift left value must be less than the bit width of the lhs, which is 32)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_U32_MoreThanBitWidth) {
+ auto* a = Scalar(u32(0x1));
+ auto* b = Scalar(u32(33));
+ auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 2);
+ EXPECT_EQ(
+ error(),
+ R"(warning: shift left value must be less than the bit width of the lhs, which is 32)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_I32_MoreThanBitWidth) {
+ auto* a = Scalar(i32(0x2));
+ auto* b = Scalar(u32(33));
+ auto result = const_eval.OpShiftRight(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), 1);
+ EXPECT_EQ(
+ error(),
+ R"(warning: shift right value must be less than the bit width of the lhs, which is 32)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_U32_MoreThanBitWidth) {
+ auto* a = Scalar(u32(0x2));
+ auto* b = Scalar(u32(33));
+ auto result = const_eval.OpShiftRight(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 1);
+ EXPECT_EQ(
+ error(),
+ R"(warning: shift right value must be less than the bit width of the lhs, which is 32)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acos_F32_OutOfRange) {
+ auto* a = Scalar(f32(2));
+ auto result = const_eval.acos(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(),
+ R"(warning: acos must be called with a value in the range [-1 .. 1] (inclusive))");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acosh_F32_OutOfRange) {
+ auto* a = Scalar(f32(-1));
+ auto result = const_eval.acosh(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: acosh must be called with a value >= 1.0)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Asin_F32_OutOfRange) {
+ auto* a = Scalar(f32(2));
+ auto result = const_eval.asin(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(),
+ R"(warning: asin must be called with a value in the range [-1 .. 1] (inclusive))");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Atanh_F32_OutOfRange) {
+ auto* a = Scalar(f32(2));
+ auto result = const_eval.atanh(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(),
+ R"(warning: atanh must be called with a value in the range (-1 .. 1) (exclusive))");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp_F32_Overflow) {
+ auto* a = Scalar(f32(1000));
+ auto result = const_eval.exp(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: e^1000 cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp2_F32_Overflow) {
+ auto* a = Scalar(f32(1000));
+ auto result = const_eval.exp2(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: 2^1000 cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_I32_TooManyBits) {
+ auto* a = Scalar(i32(0x12345678));
+ auto* offset = Scalar(u32(24));
+ auto* count = Scalar(u32(16));
+ auto result = const_eval.extractBits(a->Type(), utils::Vector{a, offset, count}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12);
+ EXPECT_EQ(error(),
+ R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_U32_TooManyBits) {
+ auto* a = Scalar(u32(0x12345678));
+ auto* offset = Scalar(u32(24));
+ auto* count = Scalar(u32(16));
+ auto result = const_eval.extractBits(a->Type(), utils::Vector{a, offset, count}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12);
+ EXPECT_EQ(error(),
+ R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_I32_TooManyBits) {
+ auto* a = Scalar(i32(0x99345678));
+ auto* b = Scalar(i32(0x12));
+ auto* offset = Scalar(u32(24));
+ auto* count = Scalar(u32(16));
+ auto result = const_eval.insertBits(a->Type(), utils::Vector{a, b, offset, count}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<i32>(), 0x12345678);
+ EXPECT_EQ(error(),
+ R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_U32_TooManyBits) {
+ auto* a = Scalar(u32(0x99345678));
+ auto* b = Scalar(u32(0x12));
+ auto* offset = Scalar(u32(24));
+ auto* count = Scalar(u32(16));
+ auto result = const_eval.insertBits(a->Type(), utils::Vector{a, b, offset, count}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x12345678);
+ EXPECT_EQ(error(),
+ R"(warning: 'offset + 'count' must be less than or equal to the bit width of 'e')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, InverseSqrt_F32_OutOfRange) {
+ auto* a = Scalar(f32(-1));
+ auto result = const_eval.inverseSqrt(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: inverseSqrt must be called with a value > 0)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, LDExpr_F32_OutOfRange) {
+ auto* a = Scalar(f32(42.f));
+ auto* b = Scalar(f32(200));
+ auto result = const_eval.ldexp(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: e2 must be less than or equal to 128)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log_F32_OutOfRange) {
+ auto* a = Scalar(f32(-1));
+ auto result = const_eval.log(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: log must be called with a value > 0)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log2_F32_OutOfRange) {
+ auto* a = Scalar(f32(-1));
+ auto result = const_eval.log2(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: log2 must be called with a value > 0)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Normalize_ZeroLength) {
+ auto* zero = Scalar(f32(0));
+ auto* vec =
+ const_eval.VecSplat(create<type::Vector>(create<type::F32>(), 4u), utils::Vector{zero}, {})
+ .Get();
+ auto result = const_eval.normalize(vec->Type(), utils::Vector{vec}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->Index(0)->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(result.Get()->Index(1)->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(result.Get()->Index(2)->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(result.Get()->Index(3)->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: zero length vector can not be normalized)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pack2x16Float_OutOfRange) {
+ auto* a = Scalar(f32(75250.f));
+ auto* b = Scalar(f32(42.1f));
+ auto* vec =
+ const_eval.VecInitS(create<type::Vector>(create<type::F32>(), 2u), utils::Vector{a, b}, {})
+ .Get();
+ auto result = const_eval.pack2x16float(create<type::U32>(), utils::Vector{vec}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 0x51430000);
+ EXPECT_EQ(error(), R"(warning: value 75250 cannot be represented as 'f16')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pow_F32_Overflow) {
+ auto* a = Scalar(f32(2));
+ auto* b = Scalar(f32(1000));
+ auto result = const_eval.pow(a->Type(), utils::Vector{a, b}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: '2 ^ 1000' cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Unpack2x16Float_OutOfRange) {
+ auto* a = Scalar(u32(0x51437C00));
+ auto result = const_eval.unpack2x16float(create<type::U32>(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_FLOAT_EQ(result.Get()->Index(0)->ValueAs<f32>(), 0.f);
+ EXPECT_FLOAT_EQ(result.Get()->Index(1)->ValueAs<f32>(), 42.09375f);
+ EXPECT_EQ(error(), R"(warning: value inf cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, QuantizeToF16_OutOfRange) {
+ auto* a = Scalar(f32(75250.f));
+ auto result = const_eval.quantizeToF16(create<type::U32>(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<u32>(), 0);
+ EXPECT_EQ(error(), R"(warning: value 75250 cannot be represented as 'f16')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sqrt_F32_OutOfRange) {
+ auto* a = Scalar(f32(-1));
+ auto result = const_eval.sqrt(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: sqrt must be called with a value >= 0)");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_Infinity) {
+ auto* a = Scalar(u32(0x7F800000));
+ auto result = const_eval.Bitcast(create<type::F32>(), a, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: value inf cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_NaN) {
+ auto* a = Scalar(u32(0x7FC00000));
+ auto result = const_eval.Bitcast(create<type::F32>(), a, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), 0.f);
+ EXPECT_EQ(error(), R"(warning: value nan cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooHigh) {
+ auto* a = Scalar(AFloat::Highest());
+ auto result = const_eval.Convert(create<type::F32>(), a, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), f32::kHighestValue);
+ EXPECT_EQ(error(),
+ R"(warning: value 1.7976931348623157081e+308 cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooLow) {
+ auto* a = Scalar(AFloat::Lowest());
+ auto result = const_eval.Convert(create<type::F32>(), a, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), f32::kLowestValue);
+ EXPECT_EQ(error(),
+ R"(warning: value -1.7976931348623157081e+308 cannot be represented as 'f32')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooHigh) {
+ auto* a = Scalar(f32(1000000.0));
+ auto result = const_eval.Convert(create<type::F16>(), a, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), f16::kHighestValue);
+ EXPECT_EQ(error(), R"(warning: value 1000000 cannot be represented as 'f16')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooLow) {
+ auto* a = Scalar(f32(-1000000.0));
+ auto result = const_eval.Convert(create<type::F16>(), a, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->ValueAs<f32>(), f16::kLowestValue);
+ EXPECT_EQ(error(), R"(warning: value -1000000 cannot be represented as 'f16')");
+}
+
+TEST_F(ResolverConstEvalRuntimeSemanticsTest, Vec_Overflow_SingleComponent) {
+ // Test that overflow for an element-wise vector operation only affects a single component.
+ auto* vec4f = create<type::Vector>(create<type::F32>(), 4u);
+ auto* a = const_eval
+ .VecInitS(vec4f,
+ utils::Vector{
+ Scalar(f32(1)),
+ Scalar(f32(4)),
+ Scalar(f32(-1)),
+ Scalar(f32(65536)),
+ },
+ {})
+ .Get();
+ auto result = const_eval.sqrt(a->Type(), utils::Vector{a}, {});
+ ASSERT_TRUE(result);
+ EXPECT_EQ(result.Get()->Index(0)->ValueAs<f32>(), 1);
+ EXPECT_EQ(result.Get()->Index(1)->ValueAs<f32>(), 2);
+ EXPECT_EQ(result.Get()->Index(2)->ValueAs<f32>(), 0);
+ EXPECT_EQ(result.Get()->Index(3)->ValueAs<f32>(), 256);
+ EXPECT_EQ(error(), R"(warning: sqrt must be called with a value >= 0)");
+}
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 7dbd6ae..759ab5d 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -1971,7 +1971,7 @@
if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) {
stage = sem::EvaluationStage::kNotEvaluated;
} else {
- if (auto r = const_eval_.Index(obj, idx)) {
+ if (auto r = const_eval_.Index(ty, obj, idx)) {
val = r.Get();
} else {
return nullptr;