tint: const eval of comparison operations
Change-Id: Iec6e78dbe00baaed8c90e709447a20f6c8ac9fb0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101304
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def
index 8a80bf5..39e0b14 100644
--- a/src/tint/intrinsics.def
+++ b/src/tint/intrinsics.def
@@ -934,23 +934,23 @@
op && (bool, bool) -> bool
op || (bool, bool) -> bool
-op == <T: scalar>(T, T) -> bool
-op == <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
+@const op == <T: abstract_or_scalar>(T, T) -> bool
+@const op == <T: abstract_or_scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
-op != <T: scalar>(T, T) -> bool
-op != <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
+@const op != <T: abstract_or_scalar>(T, T) -> bool
+@const op != <T: abstract_or_scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
-op < <T: fiu32_f16>(T, T) -> bool
-op < <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
+@const op < <T: fia_fiu32_f16>(T, T) -> bool
+@const op < <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
-op > <T: fiu32_f16>(T, T) -> bool
-op > <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
+@const op > <T: fia_fiu32_f16>(T, T) -> bool
+@const op > <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
-op <= <T: fiu32_f16>(T, T) -> bool
-op <= <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
+@const op <= <T: fia_fiu32_f16>(T, T) -> bool
+@const op <= <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
-op >= <T: fiu32_f16>(T, T) -> bool
-op >= <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
+@const op >= <T: fia_fiu32_f16>(T, T) -> bool
+@const op >= <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op << <T: iu32>(T, u32) -> T
op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
diff --git a/src/tint/number.h b/src/tint/number.h
index 29ae227..6e032d7 100644
--- a/src/tint/number.h
+++ b/src/tint/number.h
@@ -260,7 +260,7 @@
using f16 = Number<detail::NumberKindF16>;
/// @returns the friendly name of Number type T
-template <typename T, typename = traits::EnableIf<IsNumber<T>>>
+template <typename T, traits::EnableIf<IsNumber<T>>* = nullptr>
const char* FriendlyName() {
if constexpr (std::is_same_v<T, AInt>) {
return "abstract-int";
@@ -279,6 +279,12 @@
}
}
+/// Overload of FriendlyName() that takes part in overload resolution only when T is exactly
+/// `bool`, complementing the overload for Number types declared above.
+/// @returns the string "bool"
+template <typename T, traits::EnableIf<std::is_same_v<T, bool>>* = nullptr>
+const char* FriendlyName() {
+ return "bool";
+}
+
/// Enumerator of failure reasons when converting from one number to another.
enum class ConversionFailure {
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc
index 0f4872b..991fe79 100644
--- a/src/tint/resolver/const_eval.cc
+++ b/src/tint/resolver/const_eval.cc
@@ -94,6 +94,21 @@
/// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS>
+auto Dispatch_fia_fiu32_f16_bool(F&& f, CONSTANTS&&... cs) {
+    // Only the first constant's type is inspected; every constant is then read as that type.
+    auto* first_ty = First(cs...)->Type();
+    return Switch(
+        first_ty,  //
+        [&](const sem::AbstractInt*) { return f(cs->template As<AInt>()...); },
+        [&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); },
+        [&](const sem::F32*) { return f(cs->template As<f32>()...); },
+        [&](const sem::I32*) { return f(cs->template As<i32>()...); },
+        [&](const sem::U32*) { return f(cs->template As<u32>()...); },
+        [&](const sem::F16*) { return f(cs->template As<f16>()...); },
+        [&](const sem::Bool*) { return f(cs->template As<bool>()...); });
+}
+
+/// Helper that calls `f` passing in the value of all `cs`.
+/// Assumes all `cs` are of the same type.
+template <typename F, typename... CONSTANTS>
auto Dispatch_fa_f32_f16(F&& f, CONSTANTS&&... cs) {
return Switch(
First(cs...)->Type(), //
@@ -466,10 +481,14 @@
}
}
-/// TransformElements constructs a new constant by applying the transformation function 'f' on each
-/// of the most deeply nested elements of 'cs'. Assumes that all constants are the same type.
+/// TransformElements constructs a new constant of type `composite_ty` by applying the
+/// transformation function 'f' on each of the most deeply nested elements of 'cs'. Assumes that all
+/// input constants `cs` are of the same type.
template <typename F, typename... CONSTANTS>
-const Constant* TransformElements(ProgramBuilder& builder, F&& f, CONSTANTS&&... cs) {
+const Constant* TransformElements(ProgramBuilder& builder,
+ const sem::Type* composite_ty,
+ F&& f,
+ CONSTANTS&&... cs) {
uint32_t n = 0;
auto* ty = First(cs...)->Type();
auto* el_ty = sem::Type::ElementOf(ty, &n);
@@ -479,16 +498,19 @@
utils::Vector<const sem::Constant*, 8> els;
els.Reserve(n);
for (uint32_t i = 0; i < n; i++) {
- els.Push(TransformElements(builder, std::forward<F>(f), cs->Index(i)...));
+ els.Push(TransformElements(builder, sem::Type::ElementOf(composite_ty), std::forward<F>(f),
+ cs->Index(i)...));
}
- return CreateComposite(builder, ty, std::move(els));
+ return CreateComposite(builder, composite_ty, std::move(els));
}
-/// TransformBinaryElements constructs a new constant by applying the transformation function 'f' on
-/// each of the most deeply nested elements of both `c0` and `c1`. Unlike TransformElements, this
-/// function handles the constants being of different types, e.g. vector-scalar, scalar-vector.
+/// TransformBinaryElements constructs a new constant of type `composite_ty` by applying the
+/// transformation function 'f' on each of the most deeply nested elements of both `c0` and `c1`.
+/// Unlike TransformElements, this function handles the constants being of different types, e.g.
+/// vector-scalar, scalar-vector.
template <typename F>
const Constant* TransformBinaryElements(ProgramBuilder& builder,
+ const sem::Type* composite_ty,
F&& f,
const sem::Constant* c0,
const sem::Constant* c1) {
@@ -510,12 +532,11 @@
}
return c->Index(i);
};
- els.Push(TransformBinaryElements(builder, std::forward<F>(f), nested_or_self(c0, n0),
+ els.Push(TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty),
+ std::forward<F>(f), nested_or_self(c0, n0),
nested_or_self(c1, n1)));
}
- // Use larger type
- auto* ty = n0 > n1 ? c0->Type() : c1->Type();
- return CreateComposite(builder, ty, std::move(els));
+ return CreateComposite(builder, composite_ty, std::move(els));
}
} // namespace
@@ -915,7 +936,7 @@
return nullptr;
}
-ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*,
+ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c) {
@@ -924,10 +945,10 @@
};
return Dispatch_ia_iu32(create, c);
};
- return TransformElements(builder, transform, args[0]);
+ return TransformElements(builder, ty, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
+ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c) {
@@ -949,10 +970,10 @@
};
return Dispatch_fia_fi32_f16(create, c);
};
- return TransformElements(builder, transform, args[0]);
+ return TransformElements(builder, ty, transform, args[0]);
}
-ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*,
+ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
@@ -963,14 +984,14 @@
return nullptr;
};
- auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
+ auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
-ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*,
+ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
@@ -1003,14 +1024,14 @@
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
+ auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
-ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/,
+ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
@@ -1021,7 +1042,7 @@
return nullptr;
};
- auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
+ auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
@@ -1196,7 +1217,7 @@
return CreateComposite(builder, ty, result_mat);
}
-ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type*,
+ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
@@ -1237,14 +1258,116 @@
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
- auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
+ auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
-ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
+ConstEval::ConstantResult ConstEval::OpEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        // `ty` is the bool / vecN<bool> expression type, so its deepest element type is the type
+        // of each per-element result.
+        auto* result_el_ty = sem::Type::DeepestElementOf(ty);
+        auto create = [&](auto lhs, auto rhs) -> const Constant* {
+            return CreateElement(builder, result_el_ty, lhs == rhs);
+        };
+        return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
+    };
+
+    auto result = TransformElements(builder, ty, transform, args[0], args[1]);
+    if (builder.Diagnostics().contains_errors()) {
+        return utils::Failure;
+    }
+    return result;
+}
+
+ConstEval::ConstantResult ConstEval::OpNotEqual(const sem::Type* ty,
+                                                utils::VectorRef<const sem::Constant*> args,
+                                                const Source&) {
+    // Compares one pair of same-typed scalar elements.
+    auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        // `ty` is the bool / vecN<bool> expression type, so its deepest element type is the type
+        // of each per-element result.
+        auto* result_el_ty = sem::Type::DeepestElementOf(ty);
+        auto create = [&](auto lhs, auto rhs) -> const Constant* {
+            return CreateElement(builder, result_el_ty, lhs != rhs);
+        };
+        return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
+    };
+
+    auto result = TransformElements(builder, ty, transform, args[0], args[1]);
+    if (builder.Diagnostics().contains_errors()) {
+        return utils::Failure;
+    }
+    return result;
+}
+
+ConstEval::ConstantResult ConstEval::OpLessThan(const sem::Type* ty,
+                                                utils::VectorRef<const sem::Constant*> args,
+                                                const Source&) {
+    // Compares one pair of same-typed scalar elements.
+    auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        // `ty` is the bool / vecN<bool> expression type, so its deepest element type is the type
+        // of each per-element result.
+        auto* result_el_ty = sem::Type::DeepestElementOf(ty);
+        auto create = [&](auto lhs, auto rhs) -> const Constant* {
+            return CreateElement(builder, result_el_ty, lhs < rhs);
+        };
+        return Dispatch_fia_fiu32_f16(create, c0, c1);
+    };
+
+    auto result = TransformElements(builder, ty, transform, args[0], args[1]);
+    if (builder.Diagnostics().contains_errors()) {
+        return utils::Failure;
+    }
+    return result;
+}
+
+ConstEval::ConstantResult ConstEval::OpGreaterThan(const sem::Type* ty,
+                                                   utils::VectorRef<const sem::Constant*> args,
+                                                   const Source&) {
+    // Compares one pair of same-typed scalar elements.
+    auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        // `ty` is the bool / vecN<bool> expression type, so its deepest element type is the type
+        // of each per-element result.
+        auto* result_el_ty = sem::Type::DeepestElementOf(ty);
+        auto create = [&](auto lhs, auto rhs) -> const Constant* {
+            return CreateElement(builder, result_el_ty, lhs > rhs);
+        };
+        return Dispatch_fia_fiu32_f16(create, c0, c1);
+    };
+
+    auto result = TransformElements(builder, ty, transform, args[0], args[1]);
+    if (builder.Diagnostics().contains_errors()) {
+        return utils::Failure;
+    }
+    return result;
+}
+
+ConstEval::ConstantResult ConstEval::OpLessThanEqual(const sem::Type* ty,
+                                                     utils::VectorRef<const sem::Constant*> args,
+                                                     const Source&) {
+    // Compares one pair of same-typed scalar elements.
+    auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        // `ty` is the bool / vecN<bool> expression type, so its deepest element type is the type
+        // of each per-element result.
+        auto* result_el_ty = sem::Type::DeepestElementOf(ty);
+        auto create = [&](auto lhs, auto rhs) -> const Constant* {
+            return CreateElement(builder, result_el_ty, lhs <= rhs);
+        };
+        return Dispatch_fia_fiu32_f16(create, c0, c1);
+    };
+
+    auto result = TransformElements(builder, ty, transform, args[0], args[1]);
+    if (builder.Diagnostics().contains_errors()) {
+        return utils::Failure;
+    }
+    return result;
+}
+
+ConstEval::ConstantResult ConstEval::OpGreaterThanEqual(const sem::Type* ty,
+                                                        utils::VectorRef<const sem::Constant*> args,
+                                                        const Source&) {
+    // Compares one pair of same-typed scalar elements.
+    auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
+        // `ty` is the bool / vecN<bool> expression type, so its deepest element type is the type
+        // of each per-element result.
+        auto* result_el_ty = sem::Type::DeepestElementOf(ty);
+        auto create = [&](auto lhs, auto rhs) -> const Constant* {
+            return CreateElement(builder, result_el_ty, lhs >= rhs);
+        };
+        return Dispatch_fia_fiu32_f16(create, c0, c1);
+    };
+
+    auto result = TransformElements(builder, ty, transform, args[0], args[1]);
+    if (builder.Diagnostics().contains_errors()) {
+        return utils::Failure;
+    }
+    return result;
+}
+
+ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
@@ -1253,10 +1376,10 @@
};
return Dispatch_fa_f32_f16(create, c0, c1);
};
- return TransformElements(builder, transform, args[0], args[1]);
+ return TransformElements(builder, ty, transform, args[0], args[1]);
}
-ConstEval::ConstantResult ConstEval::clamp(const sem::Type*,
+ConstEval::ConstantResult ConstEval::clamp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
@@ -1267,7 +1390,7 @@
};
return Dispatch_fia_fiu32_f16(create, c0, c1, c2);
};
- return TransformElements(builder, transform, args[0], args[1], args[2]);
+ return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty,
diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h
index f84e28c..df98d58 100644
--- a/src/tint/resolver/const_eval.h
+++ b/src/tint/resolver/const_eval.h
@@ -275,6 +275,60 @@
utils::VectorRef<const sem::Constant*> args,
const Source& source);
+ /// Const-evaluates the equality operator '==' (element-wise when the operands are vectors)
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the expression (not used by this operator)
+ /// @return the result value, or utils::Failure if the builder's diagnostics contain errors
+ ConstantResult OpEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Const-evaluates the inequality operator '!=' (element-wise when the operands are vectors)
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the expression (not used by this operator)
+ /// @return the result value, or utils::Failure if the builder's diagnostics contain errors
+ ConstantResult OpNotEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Const-evaluates the less than operator '<' (element-wise when the operands are vectors)
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the expression (not used by this operator)
+ /// @return the result value, or utils::Failure if the builder's diagnostics contain errors
+ ConstantResult OpLessThan(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Const-evaluates the greater than operator '>' (element-wise when the operands are vectors)
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the expression (not used by this operator)
+ /// @return the result value, or utils::Failure if the builder's diagnostics contain errors
+ ConstantResult OpGreaterThan(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Const-evaluates the less than or equal operator '<=' (element-wise when the operands are
+ /// vectors)
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the expression (not used by this operator)
+ /// @return the result value, or utils::Failure if the builder's diagnostics contain errors
+ ConstantResult OpLessThanEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
+ /// Const-evaluates the greater than or equal operator '>=' (element-wise when the operands are
+ /// vectors)
+ /// @param ty the expression type
+ /// @param args the input arguments
+ /// @param source the source location of the expression (not used by this operator)
+ /// @return the result value, or utils::Failure if the builder's diagnostics contain errors
+ ConstantResult OpGreaterThanEqual(const sem::Type* ty,
+ utils::VectorRef<const sem::Constant*> args,
+ const Source& source);
+
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc
index 1f745a3..8b48e8f 100644
--- a/src/tint/resolver/const_eval_test.cc
+++ b/src/tint/resolver/const_eval_test.cc
@@ -3202,6 +3202,7 @@
Value<i32>,
Value<f32>,
Value<f16>,
+ Value<bool>,
Value<builder::vec2<AInt>>,
Value<builder::vec2<AFloat>>,
@@ -3209,6 +3210,7 @@
Value<builder::vec2<i32>>,
Value<builder::vec2<f32>>,
Value<builder::vec2<f16>>,
+ Value<builder::vec2<bool>>,
Value<builder::vec3<AInt>>,
Value<builder::vec3<AFloat>>,
@@ -3584,6 +3586,115 @@
OpDivFloatCases<f32>(),
OpDivFloatCases<f16>()))));
+/// Builds the scalar and vec2 cases shared by '==' and '!='.
+/// `equals` is the expected result for a pair of equal operands; instantiating with
+/// `equals == false` negates every expectation, giving the '!=' table.
+template <typename T, bool equals>
+std::vector<Case> OpEqualCases() {
+    return {
+        C(Val(T{0}), Val(T{0}), Val(equals)),
+        C(Val(T{0}), Val(T{1}), Val(!equals)),
+        C(Val(T{1}), Val(T{0}), Val(!equals)),
+        C(Val(T{1}), Val(T{1}), Val(equals)),
+        C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(equals, equals)),
+        C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(!equals, !equals)),
+        C(Vec(T{1}, T{1}), Vec(T{0}, T{1}), Vec(!equals, equals)),
+    };
+}
+// '==' is checked against OpEqualCases<T, true> for each abstract, concrete numeric and bool
+// element type.
+INSTANTIATE_TEST_SUITE_P(Equal,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kEqual),
+ testing::ValuesIn(Concat( //
+ OpEqualCases<AInt, true>(),
+ OpEqualCases<i32, true>(),
+ OpEqualCases<u32, true>(),
+ OpEqualCases<AFloat, true>(),
+ OpEqualCases<f32, true>(),
+ OpEqualCases<f16, true>(),
+ OpEqualCases<bool, true>()))));
+// '!=' reuses the same operands with every expected result negated (equals == false).
+INSTANTIATE_TEST_SUITE_P(NotEqual,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kNotEqual),
+ testing::ValuesIn(Concat( //
+ OpEqualCases<AInt, false>(),
+ OpEqualCases<i32, false>(),
+ OpEqualCases<u32, false>(),
+ OpEqualCases<AFloat, false>(),
+ OpEqualCases<f32, false>(),
+ OpEqualCases<f16, false>(),
+ OpEqualCases<bool, false>()))));
+
+/// Builds the scalar and vec2 cases shared by '<' and '>='.
+/// `less_than` is the expected result when lhs < rhs; instantiating with `less_than == false`
+/// negates every expectation, giving the '>=' table.
+template <typename T, bool less_than>
+std::vector<Case> OpLessThanCases() {
+    return {
+        C(Val(T{0}), Val(T{0}), Val(!less_than)),
+        C(Val(T{0}), Val(T{1}), Val(less_than)),
+        C(Val(T{1}), Val(T{0}), Val(!less_than)),
+        C(Val(T{1}), Val(T{1}), Val(!less_than)),
+        C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(!less_than, !less_than)),
+        C(Vec(T{0}, T{0}), Vec(T{1}, T{1}), Vec(less_than, less_than)),
+        C(Vec(T{1}, T{1}), Vec(T{0}, T{0}), Vec(!less_than, !less_than)),
+        C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(!less_than, less_than)),
+    };
+}
+// '<' is checked against OpLessThanCases<T, true> for each abstract and concrete numeric
+// element type.
+INSTANTIATE_TEST_SUITE_P(LessThan,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kLessThan),
+ testing::ValuesIn(Concat( //
+ OpLessThanCases<AInt, true>(),
+ OpLessThanCases<i32, true>(),
+ OpLessThanCases<u32, true>(),
+ OpLessThanCases<AFloat, true>(),
+ OpLessThanCases<f32, true>(),
+ OpLessThanCases<f16, true>()))));
+// For these operands '>=' is the negation of '<', so the '<' table is reused with every expected
+// result negated (less_than == false).
+INSTANTIATE_TEST_SUITE_P(GreaterThanEqual,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kGreaterThanEqual),
+ testing::ValuesIn(Concat( //
+ OpLessThanCases<AInt, false>(),
+ OpLessThanCases<i32, false>(),
+ OpLessThanCases<u32, false>(),
+ OpLessThanCases<AFloat, false>(),
+ OpLessThanCases<f32, false>(),
+ OpLessThanCases<f16, false>()))));
+
+/// Builds the scalar and vec2 cases shared by '>' and '<='.
+/// `greater_than` is the expected result when lhs > rhs; instantiating with
+/// `greater_than == false` negates every expectation, giving the '<=' table.
+template <typename T, bool greater_than>
+std::vector<Case> OpGreaterThanCases() {
+    return {
+        C(Val(T{0}), Val(T{0}), Val(!greater_than)),
+        C(Val(T{0}), Val(T{1}), Val(!greater_than)),
+        C(Val(T{1}), Val(T{0}), Val(greater_than)),
+        C(Val(T{1}), Val(T{1}), Val(!greater_than)),
+        C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(!greater_than, !greater_than)),
+        C(Vec(T{1}, T{1}), Vec(T{0}, T{0}), Vec(greater_than, greater_than)),
+        C(Vec(T{0}, T{0}), Vec(T{1}, T{1}), Vec(!greater_than, !greater_than)),
+        C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(greater_than, !greater_than)),
+    };
+}
+// '>' is checked against OpGreaterThanCases<T, true> for each abstract and concrete numeric
+// element type.
+INSTANTIATE_TEST_SUITE_P(GreaterThan,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kGreaterThan),
+ testing::ValuesIn(Concat( //
+ OpGreaterThanCases<AInt, true>(),
+ OpGreaterThanCases<i32, true>(),
+ OpGreaterThanCases<u32, true>(),
+ OpGreaterThanCases<AFloat, true>(),
+ OpGreaterThanCases<f32, true>(),
+ OpGreaterThanCases<f16, true>()))));
+// For these operands '<=' is the negation of '>', so the '>' table is reused with every expected
+// result negated (greater_than == false).
+INSTANTIATE_TEST_SUITE_P(LessThanEqual,
+ ResolverConstEvalBinaryOpTest,
+ testing::Combine( //
+ testing::Values(ast::BinaryOp::kLessThanEqual),
+ testing::ValuesIn(Concat( //
+ OpGreaterThanCases<AInt, false>(),
+ OpGreaterThanCases<i32, false>(),
+ OpGreaterThanCases<u32, false>(),
+ OpGreaterThanCases<AFloat, false>(),
+ OpGreaterThanCases<f32, false>(),
+ OpGreaterThanCases<f16, false>()))));
+
// Tests for errors on overflow/underflow of binary operations with abstract numbers
struct OverflowCase {
ast::BinaryOp op;
@@ -3608,7 +3719,7 @@
std::string type_name = std::visit(
[&](auto&& value) {
using ValueType = std::decay_t<decltype(value)>;
- return tint::FriendlyName<typename ValueType::ElementType>();
+ return builder::FriendlyName<ValueType>();
},
c.lhs);
diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl
index 675aa41..b1ac41e 100644
--- a/src/tint/resolver/intrinsic_table.inl
+++ b/src/tint/resolver/intrinsic_table.inl
@@ -12584,12 +12584,12 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[635],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpGreaterThanEqual,
},
{
/* [369] */
@@ -12601,55 +12601,55 @@
/* parameters */ &kParameters[631],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpGreaterThanEqual,
},
{
/* [370] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[641],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpLessThanEqual,
},
{
/* [371] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[639],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpLessThanEqual,
},
{
/* [372] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[649],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpGreaterThan,
},
{
/* [373] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[645],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpGreaterThan,
},
{
/* [374] */
@@ -12824,24 +12824,24 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[657],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpLessThan,
},
{
/* [389] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[15],
+ /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[653],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpLessThan,
},
{
/* [390] */
@@ -12992,48 +12992,48 @@
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[16],
+ /* template types */ &kTemplateTypes[18],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[659],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpNotEqual,
},
{
/* [403] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[16],
+ /* template types */ &kTemplateTypes[18],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[599],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpNotEqual,
},
{
/* [404] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
- /* template types */ &kTemplateTypes[16],
+ /* template types */ &kTemplateTypes[18],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[663],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpEqual,
},
{
/* [405] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
- /* template types */ &kTemplateTypes[16],
+ /* template types */ &kTemplateTypes[18],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[661],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
- /* const eval */ nullptr,
+ /* const eval */ &ConstEval::OpEqual,
},
{
/* [406] */
@@ -14699,42 +14699,42 @@
},
{
/* [10] */
- /* op ==<T : scalar>(T, T) -> bool */
- /* op ==<T : scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
+ /* op ==<T : abstract_or_scalar>(T, T) -> bool */
+ /* op ==<T : abstract_or_scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2,
/* overloads */ &kOverloads[404],
},
{
/* [11] */
- /* op !=<T : scalar>(T, T) -> bool */
- /* op !=<T : scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
+ /* op !=<T : abstract_or_scalar>(T, T) -> bool */
+ /* op !=<T : abstract_or_scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2,
/* overloads */ &kOverloads[402],
},
{
/* [12] */
- /* op <<T : fiu32_f16>(T, T) -> bool */
- /* op <<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
+ /* op <<T : fia_fiu32_f16>(T, T) -> bool */
+ /* op <<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2,
/* overloads */ &kOverloads[388],
},
{
/* [13] */
- /* op ><T : fiu32_f16>(T, T) -> bool */
- /* op ><T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
+ /* op ><T : fia_fiu32_f16>(T, T) -> bool */
+ /* op ><T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2,
/* overloads */ &kOverloads[372],
},
{
/* [14] */
- /* op <=<T : fiu32_f16>(T, T) -> bool */
- /* op <=<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
+ /* op <=<T : fia_fiu32_f16>(T, T) -> bool */
+ /* op <=<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2,
/* overloads */ &kOverloads[370],
},
{
/* [15] */
- /* op >=<T : fiu32_f16>(T, T) -> bool */
+ /* op >=<T : fia_fiu32_f16>(T, T) -> bool */
/* op >=<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2,
/* overloads */ &kOverloads[368],
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 1eb3b35..fd0d821 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2510,6 +2510,7 @@
if (!op.result) {
return nullptr;
}
+ ty = op.result;
if (ShouldMaterializeArgument(op.parameter)) {
expr = Materialize(expr, op.parameter);
if (!expr) {
@@ -2530,7 +2531,6 @@
stage = sem::EvaluationStage::kRuntime;
}
}
- ty = op.result;
break;
}
}
diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h
index 6641176..77d1481 100644
--- a/src/tint/resolver/resolver_test_helper.h
+++ b/src/tint/resolver/resolver_test_helper.h
@@ -751,6 +751,12 @@
template <typename T>
constexpr bool IsValue = detail::IsValue<T>::value;
+/// Returns the friendly name of the element type of ValueT.
+/// Participates in overload resolution only when IsValue<ValueT> is true. The constraint is
+/// written as a defaulted pointer non-type parameter, the same form tint::FriendlyName() uses,
+/// so that further FriendlyName() overloads with other constraints can be declared alongside it.
+/// @returns the result of tint::FriendlyName() for `ValueT::ElementType`
+template <typename ValueT, traits::EnableIf<IsValue<ValueT>>* = nullptr>
+const char* FriendlyName() {
+    return tint::FriendlyName<typename ValueT::ElementType>();
+}
+
/// Creates a `Value<T>` from a scalar `v`
template <typename T>
auto Val(T v) {
diff --git a/src/tint/writer/spirv/builder_binary_expression_test.cc b/src/tint/writer/spirv/builder_binary_expression_test.cc
index 9c4ba08..fd9a64a 100644
--- a/src/tint/writer/spirv/builder_binary_expression_test.cc
+++ b/src/tint/writer/spirv/builder_binary_expression_test.cc
@@ -981,36 +981,52 @@
}
TEST_F(BuilderTest, Binary_LogicalAnd) {
- auto* lhs = create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1_i), Expr(2_i));
- auto* rhs = create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(3_i), Expr(4_i));
- auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, lhs, rhs);
+ auto* v0 = Var("a", Expr(1_i));  // var operands: presumably avoids const-eval of == (confirm)
+ auto* v1 = Var("b", Expr(2_i));
+ auto* v2 = Var("c", Expr(3_i));
+ auto* v3 = Var("d", Expr(4_i));
+ auto* expr = LogicalAnd(Equal("a", "b"), Equal("c", "d"));  // (a == b) && (c == d)
- WrapInFunction(expr);
+ WrapInFunction(v0, v1, v2, v3, expr);
spirv::Builder& b = Build();
b.push_function(Function{});
b.GenerateLabel(b.next_id());
+ ASSERT_TRUE(b.GenerateFunctionVariable(v0)) << b.error();  // one OpStore per variable
+ ASSERT_TRUE(b.GenerateFunctionVariable(v1)) << b.error();
+ ASSERT_TRUE(b.GenerateFunctionVariable(v2)) << b.error();
+ ASSERT_TRUE(b.GenerateFunctionVariable(v3)) << b.error();
- EXPECT_EQ(b.GenerateBinaryExpression(expr), 12u) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(expr), 22u) << b.error();  // %22 is the OpPhi
EXPECT_EQ(DumpInstructions(b.types()),
R"(%2 = OpTypeInt 32 1
%3 = OpConstant %2 1
-%4 = OpConstant %2 2
-%6 = OpTypeBool
+%5 = OpTypePointer Function %2
+%6 = OpConstantNull %2
+%7 = OpConstant %2 2
%9 = OpConstant %2 3
-%10 = OpConstant %2 4
+%11 = OpConstant %2 4
+%16 = OpTypeBool
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%1 = OpLabel
-%5 = OpIEqual %6 %3 %4
-OpSelectionMerge %7 None
-OpBranchConditional %5 %8 %7
-%8 = OpLabel
-%11 = OpIEqual %6 %9 %10
-OpBranch %7
-%7 = OpLabel
-%12 = OpPhi %6 %5 %1 %11 %8
+OpStore %4 %3
+OpStore %8 %7
+OpStore %10 %9
+OpStore %12 %11
+%13 = OpLoad %2 %4
+%14 = OpLoad %2 %8
+%15 = OpIEqual %16 %13 %14
+OpSelectionMerge %17 None
+OpBranchConditional %15 %18 %17
+%18 = OpLabel
+%19 = OpLoad %2 %10
+%20 = OpLoad %2 %12
+%21 = OpIEqual %16 %19 %20
+OpBranch %17
+%17 = OpLabel
+%22 = OpPhi %16 %15 %1 %21 %18
)");
}
@@ -1131,38 +1147,52 @@
}
TEST_F(BuilderTest, Binary_LogicalOr) {
- auto* lhs = create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1_i), Expr(2_i));
+ auto* v0 = Var("a", Expr(1_i));  // var operands: presumably avoids const-eval of == (confirm)
+ auto* v1 = Var("b", Expr(2_i));
+ auto* v2 = Var("c", Expr(3_i));
+ auto* v3 = Var("d", Expr(4_i));
+ auto* expr = LogicalOr(Equal("a", "b"), Equal("c", "d"));  // (a == b) || (c == d)
- auto* rhs = create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(3_i), Expr(4_i));
-
- auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, lhs, rhs);
-
- WrapInFunction(expr);
+ WrapInFunction(v0, v1, v2, v3, expr);
spirv::Builder& b = Build();
b.push_function(Function{});
b.GenerateLabel(b.next_id());
+ ASSERT_TRUE(b.GenerateFunctionVariable(v0)) << b.error();  // one OpStore per variable
+ ASSERT_TRUE(b.GenerateFunctionVariable(v1)) << b.error();
+ ASSERT_TRUE(b.GenerateFunctionVariable(v2)) << b.error();
+ ASSERT_TRUE(b.GenerateFunctionVariable(v3)) << b.error();
- EXPECT_EQ(b.GenerateBinaryExpression(expr), 12u) << b.error();
+ EXPECT_EQ(b.GenerateBinaryExpression(expr), 22u) << b.error();  // %22 is the OpPhi
EXPECT_EQ(DumpInstructions(b.types()),
R"(%2 = OpTypeInt 32 1
%3 = OpConstant %2 1
-%4 = OpConstant %2 2
-%6 = OpTypeBool
+%5 = OpTypePointer Function %2
+%6 = OpConstantNull %2
+%7 = OpConstant %2 2
%9 = OpConstant %2 3
-%10 = OpConstant %2 4
+%11 = OpConstant %2 4
+%16 = OpTypeBool
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%1 = OpLabel
-%5 = OpIEqual %6 %3 %4
-OpSelectionMerge %7 None
-OpBranchConditional %5 %7 %8
-%8 = OpLabel
-%11 = OpIEqual %6 %9 %10
-OpBranch %7
-%7 = OpLabel
-%12 = OpPhi %6 %5 %1 %11 %8
+OpStore %4 %3
+OpStore %8 %7
+OpStore %10 %9
+OpStore %12 %11
+%13 = OpLoad %2 %4
+%14 = OpLoad %2 %8
+%15 = OpIEqual %16 %13 %14
+OpSelectionMerge %17 None
+OpBranchConditional %15 %17 %18
+%18 = OpLabel
+%19 = OpLoad %2 %10
+%20 = OpLoad %2 %12
+%21 = OpIEqual %16 %19 %20
+OpBranch %17
+%17 = OpLabel
+%22 = OpPhi %16 %15 %1 %21 %18
)");
}