castable: Infer the TO type of Is() from the predicate
For the Is() overload that takes a predicate function, infer the cast target type from the single parameter of the predicate.
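This removes the redundant explicit template argument from call sites: for example, `Is<Vector>([](const Vector* v) { ... })` becomes `Is([](const Vector* v) { ... })`.
Predicates that previously used a generic `auto` parameter now spell out the pointer type so that the target type can be inferred from it.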
Change-Id: Ie6248c776ca1f9d50808e03e9685056fd3819217
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/62441
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/ast_type.cc b/src/ast/ast_type.cc
index 6dbf38b..c5d6871 100644
--- a/src/ast/ast_type.cc
+++ b/src/ast/ast_type.cc
@@ -58,13 +58,11 @@
}
bool Type::is_float_matrix() const {
- return Is<Matrix>(
- [](const Matrix* m) { return m->type()->is_float_scalar(); });
+ return Is([](const Matrix* m) { return m->type()->is_float_scalar(); });
}
bool Type::is_float_vector() const {
- return Is<Vector>(
- [](const Vector* v) { return v->type()->is_float_scalar(); });
+ return Is([](const Vector* v) { return v->type()->is_float_scalar(); });
}
bool Type::is_float_scalar_or_vector() const {
@@ -80,11 +78,11 @@
}
bool Type::is_unsigned_integer_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->Is<U32>(); });
+ return Is([](const Vector* v) { return v->type()->Is<U32>(); });
}
bool Type::is_signed_integer_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->Is<I32>(); });
+ return Is([](const Vector* v) { return v->type()->Is<I32>(); });
}
bool Type::is_unsigned_scalar_or_vector() const {
@@ -100,7 +98,7 @@
}
bool Type::is_bool_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->Is<Bool>(); });
+ return Is([](const Vector* v) { return v->type()->Is<Bool>(); });
}
bool Type::is_bool_scalar_or_vector() const {
diff --git a/src/castable.h b/src/castable.h
index d825de1..ed075b0 100644
--- a/src/castable.h
+++ b/src/castable.h
@@ -17,6 +17,8 @@
#include <utility>
+#include "src/traits.h"
+
#if defined(__clang__)
/// Temporarily disable certain warnings when using Castable API
#define TINT_CASTABLE_PUSH_DISABLE_WARNINGS() \
@@ -298,8 +300,9 @@
/// pred(const TO*) returns true
/// @param pred predicate function with signature `bool(const TO*)` called iff
/// object is of, or derives from the class `TO`.
- template <typename TO, int FLAGS = 0, typename Pred = detail::Infer>
+ template <int FLAGS = 0, typename Pred = detail::Infer>
inline bool Is(Pred&& pred) const {
+ using TO = typename std::remove_pointer<traits::ParamTypeT<Pred, 0>>::type;
return tint::Is<TO, FLAGS>(static_cast<const CLASS*>(this),
std::forward<Pred>(pred));
}
diff --git a/src/castable_test.cc b/src/castable_test.cc
index 0d10aac..e44983b 100644
--- a/src/castable_test.cc
+++ b/src/castable_test.cc
@@ -214,17 +214,17 @@
TEST(Castable, IsWithPredicate) {
std::unique_ptr<Animal> frog = std::make_unique<Frog>();
- frog->Is<Animal>([&frog](const Animal* a) {
+ frog->Is([&frog](const Animal* a) {
EXPECT_EQ(a, frog.get());
return true;
});
- ASSERT_TRUE((frog->Is<Animal>([](const Animal*) { return true; })));
- ASSERT_FALSE((frog->Is<Animal>([](const Animal*) { return false; })));
+ ASSERT_TRUE((frog->Is([](const Animal*) { return true; })));
+ ASSERT_FALSE((frog->Is([](const Animal*) { return false; })));
// Predicate not called if cast is invalid
auto expect_not_called = [] { FAIL() << "Should not be called"; };
- ASSERT_FALSE((frog->Is<Bear>([&](const Animal*) {
+ ASSERT_FALSE((frog->Is([&](const Bear*) {
expect_not_called();
return true;
})));
diff --git a/src/intrinsic_table.cc b/src/intrinsic_table.cc
index 970a9dd..0f497ef 100644
--- a/src/intrinsic_table.cc
+++ b/src/intrinsic_table.cc
@@ -483,7 +483,7 @@
if (ty->Is<Any>()) {
return true;
}
- return ty->Is<sem::Sampler>([](const sem::Sampler* s) {
+ return ty->Is([](const sem::Sampler* s) {
return s->kind() == ast::SamplerKind::kSampler;
});
}
@@ -496,7 +496,7 @@
if (ty->Is<Any>()) {
return true;
}
- return ty->Is<sem::Sampler>([](const sem::Sampler* s) {
+ return ty->Is([](const sem::Sampler* s) {
return s->kind() == ast::SamplerKind::kComparisonSampler;
});
}
@@ -575,7 +575,7 @@
if (ty->Is<Any>()) {
return true;
}
- return ty->Is<sem::DepthTexture>([&](auto t) { return t->dim() == dim; });
+ return ty->Is([&](const sem::DepthTexture* t) { return t->dim() == dim; });
}
#define DECLARE_DEPTH_TEXTURE(suffix, dim) \
@@ -597,8 +597,9 @@
if (ty->Is<Any>()) {
return true;
}
- return ty->Is<sem::DepthMultisampledTexture>(
- [&](auto t) { return t->dim() == ast::TextureDimension::k2d; });
+ return ty->Is([&](const sem::DepthMultisampledTexture* t) {
+ return t->dim() == ast::TextureDimension::k2d;
+ });
}
sem::DepthMultisampledTexture* build_texture_depth_multisampled_2d(
diff --git a/src/reader/spirv/parser_type.cc b/src/reader/spirv/parser_type.cc
index 9b3e527..7802bad 100644
--- a/src/reader/spirv/parser_type.cc
+++ b/src/reader/spirv/parser_type.cc
@@ -367,7 +367,7 @@
}
bool Type::IsFloatVector() const {
- return Is<Vector>([](const Vector* v) { return v->type->IsFloatScalar(); });
+ return Is([](const Vector* v) { return v->type->IsFloatScalar(); });
}
bool Type::IsIntegerScalar() const {
@@ -383,7 +383,7 @@
}
bool Type::IsSignedIntegerVector() const {
- return Is<Vector>([](const Vector* v) { return v->type->Is<I32>(); });
+ return Is([](const Vector* v) { return v->type->Is<I32>(); });
}
bool Type::IsSignedScalarOrVector() const {
@@ -391,7 +391,7 @@
}
bool Type::IsUnsignedIntegerVector() const {
- return Is<Vector>([](const Vector* v) { return v->type->Is<U32>(); });
+ return Is([](const Vector* v) { return v->type->Is<U32>(); });
}
bool Type::IsUnsignedScalarOrVector() const {
diff --git a/src/sem/type.cc b/src/sem/type.cc
index 5f547a2..158fb73 100644
--- a/src/sem/type.cc
+++ b/src/sem/type.cc
@@ -77,13 +77,11 @@
}
bool Type::is_float_matrix() const {
- return Is<Matrix>(
- [](const Matrix* m) { return m->type()->is_float_scalar(); });
+ return Is([](const Matrix* m) { return m->type()->is_float_scalar(); });
}
bool Type::is_float_vector() const {
- return Is<Vector>(
- [](const Vector* v) { return v->type()->is_float_scalar(); });
+ return Is([](const Vector* v) { return v->type()->is_float_scalar(); });
}
bool Type::is_float_scalar_or_vector() const {
@@ -107,11 +105,11 @@
}
bool Type::is_signed_integer_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->Is<I32>(); });
+ return Is([](const Vector* v) { return v->type()->Is<I32>(); });
}
bool Type::is_unsigned_integer_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->Is<U32>(); });
+ return Is([](const Vector* v) { return v->type()->Is<U32>(); });
}
bool Type::is_unsigned_scalar_or_vector() const {
@@ -127,7 +125,7 @@
}
bool Type::is_bool_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->Is<Bool>(); });
+ return Is([](const Vector* v) { return v->type()->Is<Bool>(); });
}
bool Type::is_bool_scalar_or_vector() const {
@@ -135,12 +133,11 @@
}
bool Type::is_numeric_vector() const {
- return Is<Vector>(
- [](const Vector* v) { return v->type()->is_numeric_scalar(); });
+ return Is([](const Vector* v) { return v->type()->is_numeric_scalar(); });
}
bool Type::is_scalar_vector() const {
- return Is<Vector>([](const Vector* v) { return v->type()->is_scalar(); });
+ return Is([](const Vector* v) { return v->type()->is_scalar(); });
}
bool Type::is_numeric_scalar_or_vector() const {