tint/sem: Add Type::HoldsAbstract()
Tells you if there's an abstract numeric somewhere in the type.
Change-Id: I0573be9e57ec48f2fa63c46944214e7f5be7d67c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104823
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index 8ace823..fed75e6 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -25,6 +25,7 @@
#include "src/tint/sem/pointer.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/sampler.h"
+#include "src/tint/sem/struct.h"
#include "src/tint/sem/texture.h"
#include "src/tint/sem/u32.h"
#include "src/tint/sem/vector.h"
@@ -172,6 +173,23 @@
return IsAnyOf<Sampler, Texture>();
}
+bool Type::HoldsAbstract() const {  // True if this type is, or recursively contains, an abstract numeric.
+ return Switch(
+ this, //
+ [&](const AbstractNumeric*) { return true; },  // this type is itself an abstract numeric
+ [&](const Vector* v) { return v->type()->HoldsAbstract(); },  // recurse into v->type()
+ [&](const Matrix* m) { return m->type()->HoldsAbstract(); },  // recurse into m->type()
+ [&](const Array* a) { return a->ElemType()->HoldsAbstract(); },  // recurse into a->ElemType()
+ [&](const Struct* s) {  // true as soon as any member's type holds an abstract numeric
+ for (auto* m : s->Members()) {
+ if (m->Type()->HoldsAbstract()) {
+ return true;
+ }
+ }
+ return false;  // every member was checked and none holds an abstract numeric
+ });  // NOTE(review): no Default case; TypeTest expects false for e.g. f32 - confirm Switch's no-match result.
+}
+
uint32_t Type::ConversionRank(const Type* from, const Type* to) {
if (from->UnwrapRef() == to) {
return 0;
diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h
index 434b913..ffa3d01 100644
--- a/src/tint/sem/type.h
+++ b/src/tint/sem/type.h
@@ -123,6 +123,10 @@
/// @returns true if this type is a handle type
bool is_handle() const;
+ /// @returns true if this type is an abstract-numeric, or if it is a vector, matrix, array or
+ /// structure that (recursively) holds an abstract-numeric element or member.
+ bool HoldsAbstract() const;
+
/// kNoConversion is returned from ConversionRank() when the implicit conversion is not
/// permitted.
static constexpr uint32_t kNoConversion = 0xffffffffu;
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index 02c76c2..ee79102 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -44,22 +44,38 @@
const sem::Matrix* mat4x3_af = create<Matrix>(vec3_af, 4u);
const sem::Reference* ref_u32 =
create<Reference>(u32, ast::AddressSpace::kPrivate, ast::Access::kReadWrite);
- const sem::Struct* str = create<Struct>(nullptr,
- Sym("s"),
- StructMemberList{
- create<StructMember>(
- /* declaration */ nullptr,
- /* name */ Sym("x"),
- /* type */ f16,
- /* index */ 0u,
- /* offset */ 0u,
- /* align */ 4u,
- /* size */ 4u,
- /* location */ std::nullopt),
- },
- /* align*/ 4u,
- /* size*/ 4u,
- /* size_no_padding*/ 4u);
+ const sem::Struct* str_f16 = create<Struct>(nullptr,
+ Sym("str_f16"),
+ StructMemberList{
+ create<StructMember>(
+ /* declaration */ nullptr,
+ /* name */ Sym("x"),
+ /* type */ f16,
+ /* index */ 0u,
+ /* offset */ 0u,
+ /* align */ 4u,
+ /* size */ 4u,
+ /* location */ std::nullopt),
+ },
+ /* align*/ 4u,
+ /* size*/ 4u,
+ /* size_no_padding*/ 4u);
+ const sem::Struct* str_af = create<Struct>(nullptr,
+ Sym("str_af"),
+ StructMemberList{
+ create<StructMember>(
+ /* declaration */ nullptr,
+ /* name */ Sym("x"),
+ /* type */ af,
+ /* index */ 0u,
+ /* offset */ 0u,
+ /* align */ 4u,
+ /* size */ 4u,
+ /* location */ std::nullopt),
+ },
+ /* align*/ 4u,
+ /* size*/ 4u,
+ /* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>(
/* element */ i32,
/* count */ ConstantArrayCount{5u},
@@ -109,8 +125,15 @@
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
- const sem::Array* arr_str = create<Array>(
- /* element */ str,
+ const sem::Array* arr_str_f16 = create<Array>(
+ /* element */ str_f16,
+ /* count */ ConstantArrayCount{5u},
+ /* align */ 4u,
+ /* size */ 5u * 4u,
+ /* stride */ 5u * 4u,
+ /* implicit_stride */ 5u * 4u);
+ const sem::Array* arr_str_af = create<Array>(
+ /* element */ str_af,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
@@ -192,12 +215,12 @@
EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32);
EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32);
EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16);
- EXPECT_TYPE(Type::ElementOf(str), str);
+ EXPECT_TYPE(Type::ElementOf(str_f16), str_f16);
EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16);
EXPECT_TYPE(Type::ElementOf(arr_mat4x3_af), mat4x3_af);
- EXPECT_TYPE(Type::ElementOf(arr_str), str);
+ EXPECT_TYPE(Type::ElementOf(arr_str_f16), str_f16);
// With count
uint32_t count = 42;
@@ -237,7 +260,7 @@
EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16);
EXPECT_EQ(count, 4u);
count = 42;
- EXPECT_TYPE(Type::ElementOf(str, &count), str);
+ EXPECT_TYPE(Type::ElementOf(str_f16, &count), str_f16);
EXPECT_EQ(count, 1u);
count = 42;
EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32);
@@ -252,7 +275,7 @@
EXPECT_TYPE(Type::ElementOf(arr_mat4x3_af, &count), mat4x3_af);
EXPECT_EQ(count, 5u);
count = 42;
- EXPECT_TYPE(Type::ElementOf(arr_str, &count), str);
+ EXPECT_TYPE(Type::ElementOf(arr_str_f16, &count), str_f16);
EXPECT_EQ(count, 5u);
}
@@ -270,12 +293,12 @@
EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32);
EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32);
EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16);
- EXPECT_TYPE(Type::DeepestElementOf(str), str);
+ EXPECT_TYPE(Type::DeepestElementOf(str_f16), str_f16);
EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32);
EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32);
EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16);
EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af), af);
- EXPECT_TYPE(Type::DeepestElementOf(arr_str), str);
+ EXPECT_TYPE(Type::DeepestElementOf(arr_str_f16), str_f16);
// With count
uint32_t count = 42;
@@ -315,7 +338,7 @@
EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16);
EXPECT_EQ(count, 12u);
count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(str, &count), str);
+ EXPECT_TYPE(Type::DeepestElementOf(str_f16, &count), str_f16);
EXPECT_EQ(count, 1u);
count = 42;
EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32);
@@ -330,7 +353,7 @@
EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af, &count), af);
EXPECT_EQ(count, 60u);
count = 42;
- EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), str);
+ EXPECT_TYPE(Type::DeepestElementOf(arr_str_f16, &count), str_f16);
EXPECT_EQ(count, 5u);
}
@@ -525,5 +548,39 @@
arr_mat4x3_f16);
}
+TEST_F(TypeTest, HoldsAbstract) {  // Pins the expected HoldsAbstract() result for each fixture type.
+ EXPECT_TRUE(af->HoldsAbstract());
+ EXPECT_TRUE(ai->HoldsAbstract());
+ EXPECT_FALSE(f32->HoldsAbstract());
+ EXPECT_FALSE(f16->HoldsAbstract());
+ EXPECT_FALSE(i32->HoldsAbstract());
+ EXPECT_FALSE(u32->HoldsAbstract());
+ EXPECT_FALSE(vec2_f32->HoldsAbstract());
+ EXPECT_FALSE(vec3_f32->HoldsAbstract());
+ EXPECT_FALSE(vec3_f16->HoldsAbstract());
+ EXPECT_FALSE(vec4_f32->HoldsAbstract());
+ EXPECT_FALSE(vec3_u32->HoldsAbstract());
+ EXPECT_FALSE(vec3_i32->HoldsAbstract());
+ EXPECT_TRUE(vec3_af->HoldsAbstract());
+ EXPECT_TRUE(vec3_ai->HoldsAbstract());
+ EXPECT_FALSE(mat2x4_f32->HoldsAbstract());
+ EXPECT_FALSE(mat3x4_f32->HoldsAbstract());
+ EXPECT_FALSE(mat4x2_f32->HoldsAbstract());
+ EXPECT_FALSE(mat4x3_f32->HoldsAbstract());
+ EXPECT_FALSE(mat4x3_f16->HoldsAbstract());
+ EXPECT_TRUE(mat4x3_af->HoldsAbstract());  // matrix created from vec3_af
+ EXPECT_FALSE(str_f16->HoldsAbstract());  // struct whose only member has type f16
+ EXPECT_TRUE(str_af->HoldsAbstract());  // struct whose only member has type af
+ EXPECT_FALSE(arr_i32->HoldsAbstract());
+ EXPECT_TRUE(arr_ai->HoldsAbstract());
+ EXPECT_FALSE(arr_vec3_i32->HoldsAbstract());
+ EXPECT_TRUE(arr_vec3_ai->HoldsAbstract());
+ EXPECT_FALSE(arr_mat4x3_f16->HoldsAbstract());
+ EXPECT_FALSE(arr_mat4x3_f32->HoldsAbstract());
+ EXPECT_TRUE(arr_mat4x3_af->HoldsAbstract());
+ EXPECT_FALSE(arr_str_f16->HoldsAbstract());  // array whose element is str_f16
+ EXPECT_TRUE(arr_str_af->HoldsAbstract());  // array whose element is str_af
+}
+
} // namespace
} // namespace tint::sem