tint/sem: Add Type::DeepestElementOf()

Like `ElementOf()`, but returns the most nested element type.

Change-Id: Ieb97f830293d4714d0d5ddc0c9304e41e994f61b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94324
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index 40666e3..d92b236 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -229,9 +229,29 @@
                 *count = a->Count();
             }
             return a->ElemType();
+        },
+        [&](Default) {
+            if (count) {
+                *count = 0;
+            }
+            return nullptr;
         });
 }
 
+const Type* Type::DeepestElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
+    auto el_ty = ElementOf(ty, count);
+    while (el_ty && ty != el_ty) {
+        ty = el_ty;
+
+        uint32_t n = 0;
+        el_ty = ElementOf(ty, &n);
+        if (count) {
+            *count *= n;
+        }
+    }
+    return el_ty;
+}
+
 const sem::Type* Type::Common(Type const* const* types, size_t count) {
     if (count == 0) {
         return nullptr;
diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h
index 9987637..ac863207 100644
--- a/src/tint/sem/type.h
+++ b/src/tint/sem/type.h
@@ -130,11 +130,19 @@
     static uint32_t ConversionRank(const Type* from, const Type* to);
 
     /// @param ty the type to obtain the element type from
-    /// @param count if not null, then this is assigned the number of elements in the type
-    /// @returns `ty` if `ty` is an abstract or scalar, the element type if ty is a vector, matrix
-    /// or array, otherwise nullptr.
+    /// @param count if not null, then this is assigned the number of child elements in the type.
+    /// For example, the count of an `array<vec3<f32>, 5>` type would be 5.
+    /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
+    /// matrix or array, otherwise nullptr.
     static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr);
 
+    /// @param ty the type to obtain the deepest element type from
+    /// @param count if not null, then this is assigned the full number of most deeply nested
+    /// elements in the type. For example, the count of an `array<vec3<f32>, 5>` type would be 15.
+    /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
+    /// matrix, or the deepest element type if ty is an array, otherwise nullptr.
+    static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr);
+
     /// @param types a pointer to a list of `const Type*`.
     /// @param count the number of types in `types`.
     /// @returns the lowest-ranking type that all types in `types` can be implicitly converted to,
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index c11efea..e604453 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -104,6 +104,20 @@
     auto* mat2x4_f32 = create<Matrix>(vec4_f32, 2u);
     auto* mat4x2_f32 = create<Matrix>(vec2_f32, 4u);
     auto* mat4x3_f16 = create<Matrix>(vec3_f16, 4u);
+    auto* str = create<Struct>(nullptr, Sym("s"),
+                               StructMemberList{
+                                   create<StructMember>(
+                                       /* declaration */ nullptr,
+                                       /* name */ Sym("x"),
+                                       /* type */ f16,
+                                       /* index */ 0u,
+                                       /* offset */ 0u,
+                                       /* align */ 4u,
+                                       /* size */ 4u),
+                               },
+                               /* align*/ 4u,
+                               /* size*/ 4u,
+                               /* size_no_padding*/ 4u);
     auto* arr_i32 = create<Array>(
         /* element */ i32,
         /* count */ 5u,
@@ -111,6 +125,27 @@
         /* size */ 5u * 4u,
         /* stride */ 5u * 4u,
         /* implicit_stride */ 5u * 4u);
+    auto* arr_vec3_i32 = create<Array>(
+        /* element */ vec3_i32,
+        /* count */ 5u,
+        /* align */ 16u,
+        /* size */ 5u * 16u,
+        /* stride */ 5u * 16u,
+        /* implicit_stride */ 5u * 16u);
+    auto* arr_mat4x3_f16 = create<Array>(
+        /* element */ mat4x3_f16,
+        /* count */ 5u,
+        /* align */ 64u,
+        /* size */ 5u * 64u,
+        /* stride */ 5u * 64u,
+        /* implicit_stride */ 5u * 64u);
+    auto* arr_str = create<Array>(
+        /* element */ str,
+        /* count */ 5u,
+        /* align */ 4u,
+        /* size */ 5u * 4u,
+        /* stride */ 5u * 4u,
+        /* implicit_stride */ 5u * 4u);
 
     // No count
     EXPECT_TYPE(Type::ElementOf(f32), f32);
@@ -125,48 +160,193 @@
     EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32);
     EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32);
     EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16);
+    EXPECT_TYPE(Type::ElementOf(str), nullptr);
     EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
+    EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
+    EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16);
+    EXPECT_TYPE(Type::ElementOf(arr_str), str);
 
     // With count
-    uint32_t count = 0;
+    uint32_t count = 42;
     EXPECT_TYPE(Type::ElementOf(f32, &count), f32);
     EXPECT_EQ(count, 1u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(f16, &count), f16);
     EXPECT_EQ(count, 1u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(i32, &count), i32);
     EXPECT_EQ(count, 1u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(u32, &count), u32);
     EXPECT_EQ(count, 1u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(vec2_f32, &count), f32);
     EXPECT_EQ(count, 2u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(vec3_f16, &count), f16);
     EXPECT_EQ(count, 3u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(vec4_f32, &count), f32);
     EXPECT_EQ(count, 4u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(vec3_u32, &count), u32);
     EXPECT_EQ(count, 3u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32);
     EXPECT_EQ(count, 3u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32);
     EXPECT_EQ(count, 8u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32);
     EXPECT_EQ(count, 8u);
-    count = 0;
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16);
     EXPECT_EQ(count, 12u);
-    count = 0;
+    count = 42;
+    EXPECT_TYPE(Type::ElementOf(str, &count), nullptr);
+    EXPECT_EQ(count, 0u);
+    count = 42;
     EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32);
     EXPECT_EQ(count, 5u);
+    count = 42;
+    EXPECT_TYPE(Type::ElementOf(arr_vec3_i32, &count), vec3_i32);
+    EXPECT_EQ(count, 5u);
+    count = 42;
+    EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16, &count), mat4x3_f16);
+    EXPECT_EQ(count, 5u);
+    count = 42;
+    EXPECT_TYPE(Type::ElementOf(arr_str, &count), str);
+    EXPECT_EQ(count, 5u);
+}
+
+TEST_F(TypeTest, DeepestElementOf) {
+    auto* f32 = create<F32>();
+    auto* f16 = create<F16>();
+    auto* i32 = create<I32>();
+    auto* u32 = create<U32>();
+    auto* vec2_f32 = create<Vector>(f32, 2u);
+    auto* vec3_f16 = create<Vector>(f16, 3u);
+    auto* vec4_f32 = create<Vector>(f32, 4u);
+    auto* vec3_u32 = create<Vector>(u32, 3u);
+    auto* vec3_i32 = create<Vector>(i32, 3u);
+    auto* mat2x4_f32 = create<Matrix>(vec4_f32, 2u);
+    auto* mat4x2_f32 = create<Matrix>(vec2_f32, 4u);
+    auto* mat4x3_f16 = create<Matrix>(vec3_f16, 4u);
+    auto* str = create<Struct>(nullptr, Sym("s"),
+                               StructMemberList{
+                                   create<StructMember>(
+                                       /* declaration */ nullptr,
+                                       /* name */ Sym("x"),
+                                       /* type */ f16,
+                                       /* index */ 0u,
+                                       /* offset */ 0u,
+                                       /* align */ 4u,
+                                       /* size */ 4u),
+                               },
+                               /* align*/ 4u,
+                               /* size*/ 4u,
+                               /* size_no_padding*/ 4u);
+    auto* arr_i32 = create<Array>(
+        /* element */ i32,
+        /* count */ 5u,
+        /* align */ 4u,
+        /* size */ 5u * 4u,
+        /* stride */ 5u * 4u,
+        /* implicit_stride */ 5u * 4u);
+    auto* arr_vec3_i32 = create<Array>(
+        /* element */ vec3_i32,
+        /* count */ 5u,
+        /* align */ 16u,
+        /* size */ 5u * 16u,
+        /* stride */ 5u * 16u,
+        /* implicit_stride */ 5u * 16u);
+    auto* arr_mat4x3_f16 = create<Array>(
+        /* element */ mat4x3_f16,
+        /* count */ 5u,
+        /* align */ 64u,
+        /* size */ 5u * 64u,
+        /* stride */ 5u * 64u,
+        /* implicit_stride */ 5u * 64u);
+    auto* arr_str = create<Array>(
+        /* element */ str,
+        /* count */ 5u,
+        /* align */ 4u,
+        /* size */ 5u * 4u,
+        /* stride */ 5u * 4u,
+        /* implicit_stride */ 5u * 4u);
+
+    // No count
+    EXPECT_TYPE(Type::DeepestElementOf(f32), f32);
+    EXPECT_TYPE(Type::DeepestElementOf(f16), f16);
+    EXPECT_TYPE(Type::DeepestElementOf(i32), i32);
+    EXPECT_TYPE(Type::DeepestElementOf(u32), u32);
+    EXPECT_TYPE(Type::DeepestElementOf(vec2_f32), f32);
+    EXPECT_TYPE(Type::DeepestElementOf(vec3_f16), f16);
+    EXPECT_TYPE(Type::DeepestElementOf(vec4_f32), f32);
+    EXPECT_TYPE(Type::DeepestElementOf(vec3_u32), u32);
+    EXPECT_TYPE(Type::DeepestElementOf(vec3_i32), i32);
+    EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32);
+    EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32);
+    EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16);
+    EXPECT_TYPE(Type::DeepestElementOf(str), nullptr);
+    EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32);
+    EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32);
+    EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16);
+    EXPECT_TYPE(Type::DeepestElementOf(arr_str), nullptr);
+
+    // With count
+    uint32_t count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(f32, &count), f32);
+    EXPECT_EQ(count, 1u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(f16, &count), f16);
+    EXPECT_EQ(count, 1u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(i32, &count), i32);
+    EXPECT_EQ(count, 1u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(u32, &count), u32);
+    EXPECT_EQ(count, 1u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(vec2_f32, &count), f32);
+    EXPECT_EQ(count, 2u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(vec3_f16, &count), f16);
+    EXPECT_EQ(count, 3u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(vec4_f32, &count), f32);
+    EXPECT_EQ(count, 4u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(vec3_u32, &count), u32);
+    EXPECT_EQ(count, 3u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(vec3_i32, &count), i32);
+    EXPECT_EQ(count, 3u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32, &count), f32);
+    EXPECT_EQ(count, 8u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32, &count), f32);
+    EXPECT_EQ(count, 8u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16);
+    EXPECT_EQ(count, 12u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(str, &count), nullptr);
+    EXPECT_EQ(count, 0u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32);
+    EXPECT_EQ(count, 5u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32, &count), i32);
+    EXPECT_EQ(count, 15u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16, &count), f16);
+    EXPECT_EQ(count, 60u);
+    count = 42;
+    EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), nullptr);
+    EXPECT_EQ(count, 0u);
 }
 
 TEST_F(TypeTest, Common2) {