tint: Add new sem::Type helpers

Add:
• sem::Type::is_abstract_or_scalar()
• sem::Type::ElementOf()

Use these to clean up some code in src/tint/sem/constant.cc.

Bug: tint:1504
Change-Id: I78e06b580a750c97ac654af4b0b364ddd3de6596
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90534
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 32c0e93..38970a5 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2636,6 +2636,25 @@
     /// the type declaration has no resolved type.
     const sem::Type* TypeOf(const ast::TypeDecl* type_decl) const;
 
+    /// @param type a type
+    /// @returns the name for `type` that closely resembles how it would be
+    /// declared in WGSL.
+    std::string FriendlyName(const ast::Type* type) {
+        return type ? type->FriendlyName(Symbols()) : "<null>";
+    }
+
+    /// @param type a type
+    /// @returns the name for `type` that closely resembles how it would be
+    /// declared in WGSL.
+    std::string FriendlyName(const sem::Type* type) {
+        return type ? type->FriendlyName(Symbols()) : "<null>";
+    }
+
+    /// Overload of FriendlyName, which removes an ambiguity when passing nullptr.
+    /// Simplifies test code.
+    /// @returns "<null>"
+    std::string FriendlyName(std::nullptr_t) { return "<null>"; }
+
     /// Wraps the ast::Expression in a statement. This is used by tests that
     /// construct a partial AST and require the Resolver to reach these
     /// nodes.
diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc
index 1c4dc58..98c724c 100644
--- a/src/tint/sem/constant.cc
+++ b/src/tint/sem/constant.cc
@@ -14,7 +14,6 @@
 
 #include "src/tint/sem/constant.h"
 
-#include <functional>
 #include <utility>
 
 #include "src/tint/debug.h"
@@ -25,24 +24,19 @@
 
 namespace {
 
-const Type* ElemType(const Type* ty, size_t num_elements) {
+const Type* CheckElemType(const Type* ty, size_t num_scalars) {
     diag::List diag;
-    if (ty->is_scalar()) {
-        if (num_elements != 1) {
-            TINT_ICE(Semantic, diag) << "sem::Constant() type <-> num_element mismatch. type: '"
-                                     << ty->TypeInfo().name << "' num_elements: " << num_elements;
+    if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {
+        uint32_t count = 0;
+        auto* el_ty = Type::ElementOf(ty, &count);
+        if (num_scalars != count) {
+            TINT_ICE(Semantic, diag) << "sem::Constant() type <-> scalar mismatch. type: '"
+                                     << ty->TypeInfo().name << "' scalar: " << num_scalars;
         }
-        return ty;
+        TINT_ASSERT(Semantic, el_ty->is_abstract_or_scalar());
+        return el_ty;
     }
-    if (auto* vec = ty->As<Vector>()) {
-        if (num_elements != vec->Width()) {
-            TINT_ICE(Semantic, diag) << "sem::Constant() type <-> num_element mismatch. type: '"
-                                     << ty->TypeInfo().name << "' num_elements: " << num_elements;
-        }
-        TINT_ASSERT(Semantic, vec->type()->is_scalar());
-        return vec->type();
-    }
-    TINT_UNREACHABLE(Semantic, diag) << "Unsupported sem::Constant type";
+    TINT_UNREACHABLE(Semantic, diag) << "Unsupported sem::Constant type: " << ty->TypeInfo().name;
     return nullptr;
 }
 
@@ -51,7 +45,7 @@
 Constant::Constant() {}
 
 Constant::Constant(const sem::Type* ty, Scalars els)
-    : type_(ty), elem_type_(ElemType(ty, els.size())), elems_(std::move(els)) {}
+    : type_(ty), elem_type_(CheckElemType(ty, els.size())), elems_(std::move(els)) {}
 
 Constant::Constant(const Constant&) = default;
 
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index 06f2a8d..df38a2b 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -16,6 +16,7 @@
 
 #include "src/tint/sem/abstract_float.h"
 #include "src/tint/sem/abstract_int.h"
+#include "src/tint/sem/array.h"
 #include "src/tint/sem/bool.h"
 #include "src/tint/sem/f16.h"
 #include "src/tint/sem/f32.h"
@@ -70,6 +71,10 @@
     return IsAnyOf<F16, F32, U32, I32, Bool>();
 }
 
+bool Type::is_abstract_or_scalar() const {
+    return IsAnyOf<F16, F32, U32, I32, Bool, AbstractNumeric>();
+}
+
 bool Type::is_numeric_scalar() const {
     return IsAnyOf<F16, F32, U32, I32>();
 }
@@ -198,4 +203,33 @@
         [&](Default) { return kNoConversion; });
 }
 
+const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
+    if (ty->is_abstract_or_scalar()) {
+        if (count) {
+            *count = 1;
+        }
+        return ty;
+    }
+    return Switch(
+        ty,  //
+        [&](const Vector* v) {
+            if (count) {
+                *count = v->Width();
+            }
+            return v->type();
+        },
+        [&](const Matrix* m) {
+            if (count) {
+                *count = m->columns() * m->rows();
+            }
+            return m->type();
+        },
+        [&](const Array* a) {
+            if (count) {
+                *count = a->Count();
+            }
+            return a->ElemType();
+        });
+}
+
 }  // namespace tint::sem
diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h
index 56ec64b..2a7ad96 100644
--- a/src/tint/sem/type.h
+++ b/src/tint/sem/type.h
@@ -71,6 +71,8 @@
 
     /// @returns true if this type is a scalar
     bool is_scalar() const;
+    /// @returns true if this type is a scalar or an abstract numeric
+    bool is_abstract_or_scalar() const;
     /// @returns true if this type is a numeric scalar
     bool is_numeric_scalar() const;
     /// @returns true if this type is a float scalar
@@ -127,6 +129,12 @@
     /// @see https://www.w3.org/TR/WGSL/#conversion-rank
     static uint32_t ConversionRank(const Type* from, const Type* to);
 
+    /// @param ty the type to obtain the element type from
+    /// @param count if not null, then this is assigned the number of elements in the type
+    /// @returns `ty` if `ty` is an abstract or scalar, the element type if ty is a vector, matrix
+    /// or array, otherwise nullptr.
+    static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr);
+
   protected:
     Type();
 };
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index 149ecc7..62fc82f 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -91,5 +91,93 @@
     EXPECT_EQ(Type::ConversionRank(f16, ai), Type::kNoConversion);
 }
 
+/// Helper macro for testing that a semantic type was as expected
+#define EXPECT_TYPE(GOT, EXPECT)                              \
+    if ((GOT) != (EXPECT)) {                                  \
+        FAIL() << #GOT " != " #EXPECT "\n"                    \
+               << "  " #GOT ": " << FriendlyName(GOT) << "\n" \
+               << "  " #EXPECT ": " << FriendlyName(EXPECT);  \
+    }                                                         \
+    do {                                                      \
+    } while (false)
+
+TEST_F(TypeTest, ElementOf) {
+    auto* f32 = create<F32>();
+    auto* f16 = create<F16>();
+    auto* i32 = create<I32>();
+    auto* u32 = create<U32>();
+    auto* vec2_f32 = create<Vector>(f32, 2u);
+    auto* vec3_f16 = create<Vector>(f16, 3u);
+    auto* vec4_f32 = create<Vector>(f32, 4u);
+    auto* vec3_u32 = create<Vector>(u32, 3u);
+    auto* vec3_i32 = create<Vector>(i32, 3u);
+    auto* mat2x4_f32 = create<Matrix>(vec4_f32, 2u);
+    auto* mat4x2_f32 = create<Matrix>(vec2_f32, 4u);
+    auto* mat4x3_f16 = create<Matrix>(vec3_f16, 4u);
+    auto* arr_i32 = create<Array>(
+        /* element */ i32,
+        /* count */ 5u,
+        /* align */ 4u,
+        /* size */ 5u * 4u,
+        /* stride */ 5u * 4u,
+        /* implicit_stride */ 5u * 4u);
+
+    // No count
+    EXPECT_TYPE(Type::ElementOf(f32), f32);
+    EXPECT_TYPE(Type::ElementOf(f16), f16);
+    EXPECT_TYPE(Type::ElementOf(i32), i32);
+    EXPECT_TYPE(Type::ElementOf(u32), u32);
+    EXPECT_TYPE(Type::ElementOf(vec2_f32), f32);
+    EXPECT_TYPE(Type::ElementOf(vec3_f16), f16);
+    EXPECT_TYPE(Type::ElementOf(vec4_f32), f32);
+    EXPECT_TYPE(Type::ElementOf(vec3_u32), u32);
+    EXPECT_TYPE(Type::ElementOf(vec3_i32), i32);
+    EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32);
+    EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32);
+    EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16);
+    EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
+
+    // With count
+    uint32_t count = 0;
+    EXPECT_TYPE(Type::ElementOf(f32, &count), f32);
+    EXPECT_EQ(count, 1u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(f16, &count), f16);
+    EXPECT_EQ(count, 1u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(i32, &count), i32);
+    EXPECT_EQ(count, 1u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(u32, &count), u32);
+    EXPECT_EQ(count, 1u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(vec2_f32, &count), f32);
+    EXPECT_EQ(count, 2u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(vec3_f16, &count), f16);
+    EXPECT_EQ(count, 3u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(vec4_f32, &count), f32);
+    EXPECT_EQ(count, 4u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(vec3_u32, &count), u32);
+    EXPECT_EQ(count, 3u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32);
+    EXPECT_EQ(count, 3u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32);
+    EXPECT_EQ(count, 8u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32);
+    EXPECT_EQ(count, 8u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16);
+    EXPECT_EQ(count, 12u);
+    count = 0;
+    EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32);
+    EXPECT_EQ(count, 5u);
+}
+
 }  // namespace
 }  // namespace tint::sem