tint/sem: Add Type::DeepestElementOf()
Like `ElementOf()`, but returns the most nested element type.
Change-Id: Ieb97f830293d4714d0d5ddc0c9304e41e994f61b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94324
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index 40666e3..d92b236 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -229,9 +229,29 @@
*count = a->Count();
}
return a->ElemType();
+ },
+ [&](Default) {
+ if (count) {
+ *count = 0;
+ }
+ return nullptr;
});
}
+const Type* Type::DeepestElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
+ auto el_ty = ElementOf(ty, count);
+ while (el_ty && ty != el_ty) {
+ ty = el_ty;
+
+ uint32_t n = 0;
+ el_ty = ElementOf(ty, &n);
+ if (count) {
+ *count *= n;
+ }
+ }
+ return el_ty;
+}
+
const sem::Type* Type::Common(Type const* const* types, size_t count) {
if (count == 0) {
return nullptr;
diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h
index 9987637..ac863207 100644
--- a/src/tint/sem/type.h
+++ b/src/tint/sem/type.h
@@ -130,11 +130,19 @@
static uint32_t ConversionRank(const Type* from, const Type* to);
/// @param ty the type to obtain the element type from
- /// @param count if not null, then this is assigned the number of elements in the type
- /// @returns `ty` if `ty` is an abstract or scalar, the element type if ty is a vector, matrix
- /// or array, otherwise nullptr.
+ /// @param count if not null, then this is assigned the number of child elements in the type.
+ /// For example, the count of an `array<vec3<f32>, 5>` type would be 5.
+ /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
+ /// matrix or array, otherwise nullptr.
static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr);
+ /// @param ty the type to obtain the deepest element type from
+ /// @param count if not null, then this is assigned the full number of most deeply nested
+ /// elements in the type. For example, the count of an `array<vec3<f32>, 5>` type would be 15.
+ /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
+ /// matrix, or the deepest element type if ty is an array, otherwise nullptr.
+ static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr);
+
/// @param types a pointer to a list of `const Type*`.
/// @param count the number of types in `types`.
/// @returns the lowest-ranking type that all types in `types` can be implicitly converted to,
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index c11efea..e604453 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -104,6 +104,20 @@
auto* mat2x4_f32 = create<Matrix>(vec4_f32, 2u);
auto* mat4x2_f32 = create<Matrix>(vec2_f32, 4u);
auto* mat4x3_f16 = create<Matrix>(vec3_f16, 4u);
+ auto* str = create<Struct>(nullptr, Sym("s"),
+ StructMemberList{
+ create<StructMember>(
+ /* declaration */ nullptr,
+ /* name */ Sym("x"),
+ /* type */ f16,
+ /* index */ 0u,
+ /* offset */ 0u,
+ /* align */ 4u,
+ /* size */ 4u),
+ },
+ /* align*/ 4u,
+ /* size*/ 4u,
+ /* size_no_padding*/ 4u);
auto* arr_i32 = create<Array>(
/* element */ i32,
/* count */ 5u,
@@ -111,6 +125,27 @@
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
+ auto* arr_vec3_i32 = create<Array>(
+ /* element */ vec3_i32,
+ /* count */ 5u,
+ /* align */ 16u,
+ /* size */ 5u * 16u,
+ /* stride */ 5u * 16u,
+ /* implicit_stride */ 5u * 16u);
+ auto* arr_mat4x3_f16 = create<Array>(
+ /* element */ mat4x3_f16,
+ /* count */ 5u,
+ /* align */ 64u,
+ /* size */ 5u * 64u,
+ /* stride */ 5u * 64u,
+ /* implicit_stride */ 5u * 64u);
+ auto* arr_str = create<Array>(
+ /* element */ str,
+ /* count */ 5u,
+ /* align */ 4u,
+ /* size */ 5u * 4u,
+ /* stride */ 5u * 4u,
+ /* implicit_stride */ 5u * 4u);
// No count
EXPECT_TYPE(Type::ElementOf(f32), f32);
@@ -125,48 +160,193 @@
EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32);
EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32);
EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16);
+ EXPECT_TYPE(Type::ElementOf(str), nullptr);
EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
+ EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
+ EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16);
+ EXPECT_TYPE(Type::ElementOf(arr_str), str);
// With count
- uint32_t count = 0;
+ uint32_t count = 42;
EXPECT_TYPE(Type::ElementOf(f32, &count), f32);
EXPECT_EQ(count, 1u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(f16, &count), f16);
EXPECT_EQ(count, 1u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(i32, &count), i32);
EXPECT_EQ(count, 1u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(u32, &count), u32);
EXPECT_EQ(count, 1u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(vec2_f32, &count), f32);
EXPECT_EQ(count, 2u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(vec3_f16, &count), f16);
EXPECT_EQ(count, 3u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(vec4_f32, &count), f32);
EXPECT_EQ(count, 4u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(vec3_u32, &count), u32);
EXPECT_EQ(count, 3u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32);
EXPECT_EQ(count, 3u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32);
EXPECT_EQ(count, 8u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32);
EXPECT_EQ(count, 8u);
- count = 0;
+ count = 42;
EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16);
EXPECT_EQ(count, 12u);
- count = 0;
+ count = 42;
+ EXPECT_TYPE(Type::ElementOf(str, &count), nullptr);
+ EXPECT_EQ(count, 0u);
+ count = 42;
EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32);
EXPECT_EQ(count, 5u);
+ count = 42;
+ EXPECT_TYPE(Type::ElementOf(arr_vec3_i32, &count), vec3_i32);
+ EXPECT_EQ(count, 5u);
+ count = 42;
+ EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16, &count), mat4x3_f16);
+ EXPECT_EQ(count, 5u);
+ count = 42;
+ EXPECT_TYPE(Type::ElementOf(arr_str, &count), str);
+ EXPECT_EQ(count, 5u);
+}
+
+TEST_F(TypeTest, DeepestElementOf) {
+ auto* f32 = create<F32>();
+ auto* f16 = create<F16>();
+ auto* i32 = create<I32>();
+ auto* u32 = create<U32>();
+ auto* vec2_f32 = create<Vector>(f32, 2u);
+ auto* vec3_f16 = create<Vector>(f16, 3u);
+ auto* vec4_f32 = create<Vector>(f32, 4u);
+ auto* vec3_u32 = create<Vector>(u32, 3u);
+ auto* vec3_i32 = create<Vector>(i32, 3u);
+ auto* mat2x4_f32 = create<Matrix>(vec4_f32, 2u);
+ auto* mat4x2_f32 = create<Matrix>(vec2_f32, 4u);
+ auto* mat4x3_f16 = create<Matrix>(vec3_f16, 4u);
+ auto* str = create<Struct>(nullptr, Sym("s"),
+ StructMemberList{
+ create<StructMember>(
+ /* declaration */ nullptr,
+ /* name */ Sym("x"),
+ /* type */ f16,
+ /* index */ 0u,
+ /* offset */ 0u,
+ /* align */ 4u,
+ /* size */ 4u),
+ },
+ /* align*/ 4u,
+ /* size*/ 4u,
+ /* size_no_padding*/ 4u);
+ auto* arr_i32 = create<Array>(
+ /* element */ i32,
+ /* count */ 5u,
+ /* align */ 4u,
+ /* size */ 5u * 4u,
+ /* stride */ 5u * 4u,
+ /* implicit_stride */ 5u * 4u);
+ auto* arr_vec3_i32 = create<Array>(
+ /* element */ vec3_i32,
+ /* count */ 5u,
+ /* align */ 16u,
+ /* size */ 5u * 16u,
+ /* stride */ 5u * 16u,
+ /* implicit_stride */ 5u * 16u);
+ auto* arr_mat4x3_f16 = create<Array>(
+ /* element */ mat4x3_f16,
+ /* count */ 5u,
+ /* align */ 64u,
+ /* size */ 5u * 64u,
+ /* stride */ 5u * 64u,
+ /* implicit_stride */ 5u * 64u);
+ auto* arr_str = create<Array>(
+ /* element */ str,
+ /* count */ 5u,
+ /* align */ 4u,
+ /* size */ 5u * 4u,
+ /* stride */ 5u * 4u,
+ /* implicit_stride */ 5u * 4u);
+
+ // No count
+ EXPECT_TYPE(Type::DeepestElementOf(f32), f32);
+ EXPECT_TYPE(Type::DeepestElementOf(f16), f16);
+ EXPECT_TYPE(Type::DeepestElementOf(i32), i32);
+ EXPECT_TYPE(Type::DeepestElementOf(u32), u32);
+ EXPECT_TYPE(Type::DeepestElementOf(vec2_f32), f32);
+ EXPECT_TYPE(Type::DeepestElementOf(vec3_f16), f16);
+ EXPECT_TYPE(Type::DeepestElementOf(vec4_f32), f32);
+ EXPECT_TYPE(Type::DeepestElementOf(vec3_u32), u32);
+ EXPECT_TYPE(Type::DeepestElementOf(vec3_i32), i32);
+ EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32);
+ EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32);
+ EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16);
+ EXPECT_TYPE(Type::DeepestElementOf(str), nullptr);
+ EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32);
+ EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32);
+ EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16);
+ EXPECT_TYPE(Type::DeepestElementOf(arr_str), nullptr);
+
+ // With count
+ uint32_t count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(f32, &count), f32);
+ EXPECT_EQ(count, 1u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(f16, &count), f16);
+ EXPECT_EQ(count, 1u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(i32, &count), i32);
+ EXPECT_EQ(count, 1u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(u32, &count), u32);
+ EXPECT_EQ(count, 1u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(vec2_f32, &count), f32);
+ EXPECT_EQ(count, 2u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(vec3_f16, &count), f16);
+ EXPECT_EQ(count, 3u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(vec4_f32, &count), f32);
+ EXPECT_EQ(count, 4u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(vec3_u32, &count), u32);
+ EXPECT_EQ(count, 3u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(vec3_i32, &count), i32);
+ EXPECT_EQ(count, 3u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32, &count), f32);
+ EXPECT_EQ(count, 8u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32, &count), f32);
+ EXPECT_EQ(count, 8u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16);
+ EXPECT_EQ(count, 12u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(str, &count), nullptr);
+ EXPECT_EQ(count, 0u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32);
+ EXPECT_EQ(count, 5u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32, &count), i32);
+ EXPECT_EQ(count, 15u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16, &count), f16);
+ EXPECT_EQ(count, 60u);
+ count = 42;
+ EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), nullptr);
+ EXPECT_EQ(count, 0u);
}
TEST_F(TypeTest, Common2) {