[tint][type] Add Element(uint32_t)

Returns the element type with the given index.

Change-Id: Ib2a7e05e58a38c49adb853eb1b7b481cc5215f1a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/136603
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/type/array.cc b/src/tint/type/array.cc
index 45a488c..47e1ea2 100644
--- a/src/tint/type/array.cc
+++ b/src/tint/type/array.cc
@@ -114,6 +114,13 @@
     return {element_, n};
 }
 
+const Type* Array::Element(uint32_t index) const {
+    if (auto* count = count_->As<ConstantArrayCount>()) {
+        return index < count->value ? element_ : nullptr;
+    }
+    return element_;
+}
+
 Array* Array::Clone(CloneContext& ctx) const {
     auto* elem_ty = element_->Clone(ctx);
     auto* count = count_->Clone(ctx);
diff --git a/src/tint/type/array.h b/src/tint/type/array.h
index 7a1deeb..72a7479 100644
--- a/src/tint/type/array.h
+++ b/src/tint/type/array.h
@@ -101,6 +101,9 @@
     TypeAndCount Elements(const Type* type_if_invalid = nullptr,
                           uint32_t count_if_invalid = 0) const override;
 
+    /// @copydoc Type::Element
+    const Type* Element(uint32_t index) const override;
+
     /// @param ctx the clone context
     /// @returns a clone of this type
     Array* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/type/matrix.cc b/src/tint/type/matrix.cc
index 55d508b..bce7335 100644
--- a/src/tint/type/matrix.cc
+++ b/src/tint/type/matrix.cc
@@ -74,6 +74,10 @@
     return {column_type_, columns_};
 }
 
+const Vector* Matrix::Element(uint32_t index) const {
+    return index < columns_ ? column_type_ : nullptr;
+}
+
 Matrix* Matrix::Clone(CloneContext& ctx) const {
     auto* col_ty = column_type_->Clone(ctx);
     return ctx.dst.mgr->Get<Matrix>(col_ty, columns_);
diff --git a/src/tint/type/matrix.h b/src/tint/type/matrix.h
index baf7a36..bce61ef 100644
--- a/src/tint/type/matrix.h
+++ b/src/tint/type/matrix.h
@@ -17,7 +17,7 @@
 
 #include <string>
 
-#include "src/tint/type/type.h"
+#include "src/tint/type/vector.h"
 
 // Forward declarations
 namespace tint::type {
@@ -69,6 +69,9 @@
     TypeAndCount Elements(const Type* type_if_invalid = nullptr,
                           uint32_t count_if_invalid = 0) const override;
 
+    /// @copydoc Type::Element
+    const Vector* Element(uint32_t index) const override;
+
     /// @param ctx the clone context
     /// @returns a clone of this type
     Matrix* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/type/struct.cc b/src/tint/type/struct.cc
index 054072d..9b3bf49 100644
--- a/src/tint/type/struct.cc
+++ b/src/tint/type/struct.cc
@@ -165,6 +165,10 @@
     return {type_if_invalid, static_cast<uint32_t>(members_.Length())};
 }
 
+const Type* Struct::Element(uint32_t index) const {
+    return index < members_.Length() ? members_[index]->Type() : nullptr;
+}
+
 Struct* Struct::Clone(CloneContext& ctx) const {
     auto sym = ctx.dst.st->Register(name_.Name());
 
diff --git a/src/tint/type/struct.h b/src/tint/type/struct.h
index 5571bc4..720beca 100644
--- a/src/tint/type/struct.h
+++ b/src/tint/type/struct.h
@@ -162,6 +162,9 @@
     TypeAndCount Elements(const Type* type_if_invalid = nullptr,
                           uint32_t count_if_invalid = 0) const override;
 
+    /// @copydoc Type::Element
+    const Type* Element(uint32_t index) const override;
+
     /// @param ctx the clone context
     /// @returns a clone of this type
     Struct* Clone(CloneContext& ctx) const override;
diff --git a/src/tint/type/type.cc b/src/tint/type/type.cc
index 21d51ee..47f8bb4 100644
--- a/src/tint/type/type.cc
+++ b/src/tint/type/type.cc
@@ -246,6 +246,10 @@
     return {type_if_invalid, count_if_invalid};
 }
 
+const Type* Type::Element(uint32_t /* index */) const {
+    return nullptr;
+}
+
 const Type* Type::DeepestElement() const {
     const Type* ty = this;
     while (true) {
diff --git a/src/tint/type/type.h b/src/tint/type/type.h
index 6c607d4..d7c0e0a 100644
--- a/src/tint/type/type.h
+++ b/src/tint/type/type.h
@@ -196,6 +196,18 @@
     virtual TypeAndCount Elements(const Type* type_if_invalid = nullptr,
                                   uint32_t count_if_invalid = 0) const;
 
+    /// @param index the i'th element index to return
+    /// @returns The child element with the given index, or nullptr if the element does not exist.
+    ///
+    /// Examples:
+    ///  * Element(1) of `mat3x2<f32>` returns `vec2<f32>`.
+    ///  * Element(1) of `array<vec3<f32>, 5>` returns `vec3<f32>`.
+    ///  * Element(0) of `struct S { a : f32, b : i32 }` returns `f32`.
+    ///  * Element(0) of `f32` returns `nullptr`.
+    ///  * Element(3) of `vec3<f32>` returns `nullptr`.
+    ///  * Element(3) of `struct S { a : f32, b : i32 }` returns `nullptr`.
+    virtual const Type* Element(uint32_t index) const;
+
     /// @returns the most deeply nested element of the type. For non-composite types,
     /// DeepestElement() will return this type. Examples:
     ///  * Element() of `f32` returns `f32`.
diff --git a/src/tint/type/type_test.cc b/src/tint/type/type_test.cc
index d84f325..407c181 100644
--- a/src/tint/type/type_test.cc
+++ b/src/tint/type/type_test.cc
@@ -262,6 +262,43 @@
     EXPECT_EQ(arr_str_f16->Elements(arr_str_f16, 42), (TypeAndCount{str_f16, 5u}));
 }
 
+TEST_F(TypeTest, Element) {
+    EXPECT_TYPE(f32->Element(0), nullptr);
+    EXPECT_TYPE(f16->Element(1), nullptr);
+    EXPECT_TYPE(i32->Element(2), nullptr);
+    EXPECT_TYPE(u32->Element(3), nullptr);
+    EXPECT_TYPE(vec2_f32->Element(0), f32);
+    EXPECT_TYPE(vec2_f32->Element(1), f32);
+    EXPECT_TYPE(vec2_f32->Element(2), nullptr);
+    EXPECT_TYPE(vec3_f16->Element(0), f16);
+    EXPECT_TYPE(vec4_f32->Element(3), f32);
+    EXPECT_TYPE(vec4_f32->Element(4), nullptr);
+    EXPECT_TYPE(vec3_u32->Element(2), u32);
+    EXPECT_TYPE(vec3_u32->Element(3), nullptr);
+    EXPECT_TYPE(vec3_i32->Element(1), i32);
+    EXPECT_TYPE(vec3_i32->Element(4), nullptr);
+    EXPECT_TYPE(mat2x4_f32->Element(1), vec4_f32);
+    EXPECT_TYPE(mat2x4_f32->Element(2), nullptr);
+    EXPECT_TYPE(mat4x2_f32->Element(3), vec2_f32);
+    EXPECT_TYPE(mat4x2_f32->Element(4), nullptr);
+    EXPECT_TYPE(mat4x3_f16->Element(1), vec3_f16);
+    EXPECT_TYPE(mat4x3_f16->Element(5), nullptr);
+    EXPECT_TYPE(str_f16->Element(0), f16);
+    EXPECT_TYPE(str_f16->Element(1), nullptr);
+    EXPECT_TYPE(arr_i32->Element(0), i32);
+    EXPECT_TYPE(arr_i32->Element(4), i32);
+    EXPECT_TYPE(arr_i32->Element(5), nullptr);
+    EXPECT_TYPE(arr_vec3_i32->Element(4), vec3_i32);
+    EXPECT_TYPE(arr_vec3_i32->Element(5), nullptr);
+    EXPECT_TYPE(arr_mat4x3_f16->Element(1), mat4x3_f16);
+    EXPECT_TYPE(arr_mat4x3_f16->Element(10), nullptr);
+    EXPECT_TYPE(arr_mat4x3_af->Element(2), mat4x3_af);
+    EXPECT_TYPE(arr_mat4x3_af->Element(6), nullptr);
+    EXPECT_TYPE(arr_str_f16->Element(0), str_f16);
+    EXPECT_TYPE(arr_str_f16->Element(1), str_f16);
+    EXPECT_TYPE(arr_str_f16->Element(10), nullptr);
+}
+
 TEST_F(TypeTest, DeepestElement) {
     EXPECT_TYPE(f32->DeepestElement(), f32);
     EXPECT_TYPE(f16->DeepestElement(), f16);
diff --git a/src/tint/type/vector.cc b/src/tint/type/vector.cc
index 126223a..49f0f1b 100644
--- a/src/tint/type/vector.cc
+++ b/src/tint/type/vector.cc
@@ -82,4 +82,8 @@
     return {subtype_, width_};
 }
 
+const Type* Vector::Element(uint32_t index) const {
+    return index < width_ ? subtype_ : nullptr;
+}
+
 }  // namespace tint::type
diff --git a/src/tint/type/vector.h b/src/tint/type/vector.h
index e40715e..e83d917 100644
--- a/src/tint/type/vector.h
+++ b/src/tint/type/vector.h
@@ -68,6 +68,9 @@
     TypeAndCount Elements(const Type* type_if_invalid = nullptr,
                           uint32_t count_if_invalid = 0) const override;
 
+    /// @copydoc Type::Element
+    const Type* Element(uint32_t index) const override;
+
     /// @param ctx the clone context
     /// @returns a clone of this type
     Vector* Clone(CloneContext& ctx) const override;