tint/sem: Add Type::HoldsAbstract()

Returns true if the type is an abstract-numeric, or if it (transitively)
contains one as a vector/matrix component, array element or struct member.

Change-Id: I0573be9e57ec48f2fa63c46944214e7f5be7d67c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104823
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: dan sinclair <dsinclair@google.com>
diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc
index 8ace823..fed75e6 100644
--- a/src/tint/sem/type.cc
+++ b/src/tint/sem/type.cc
@@ -25,6 +25,7 @@
 #include "src/tint/sem/pointer.h"
 #include "src/tint/sem/reference.h"
 #include "src/tint/sem/sampler.h"
+#include "src/tint/sem/struct.h"
 #include "src/tint/sem/texture.h"
 #include "src/tint/sem/u32.h"
 #include "src/tint/sem/vector.h"
@@ -172,6 +173,34 @@
     return IsAnyOf<Sampler, Texture>();
 }
 
+bool Type::HoldsAbstract() const {
+    // An abstract-numeric is trivially abstract.
+    if (Is<AbstractNumeric>()) {
+        return true;
+    }
+    // Vectors and matrices are abstract when their component type is.
+    if (auto* vec = As<Vector>()) {
+        return vec->type()->HoldsAbstract();
+    }
+    if (auto* mat = As<Matrix>()) {
+        return mat->type()->HoldsAbstract();
+    }
+    // Arrays are abstract when their element type is.
+    if (auto* arr = As<Array>()) {
+        return arr->ElemType()->HoldsAbstract();
+    }
+    // Structures are abstract when any one of their members is.
+    if (auto* str = As<Struct>()) {
+        for (auto* member : str->Members()) {
+            if (member->Type()->HoldsAbstract()) {
+                return true;
+            }
+        }
+    }
+    // Any other type holds no abstract-numeric.
+    return false;
+}
+
 uint32_t Type::ConversionRank(const Type* from, const Type* to) {
     if (from->UnwrapRef() == to) {
         return 0;
diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h
index 434b913..ffa3d01 100644
--- a/src/tint/sem/type.h
+++ b/src/tint/sem/type.h
@@ -123,6 +123,10 @@
     /// @returns true if this type is a handle type
     bool is_handle() const;
 
+    /// @returns true if this type is an abstract-numeric, or is a vector, matrix, array or
+    /// structure whose elements or members (recursively) hold an abstract-numeric.
+    bool HoldsAbstract() const;
+
     /// kNoConversion is returned from ConversionRank() when the implicit conversion is not
     /// permitted.
     static constexpr uint32_t kNoConversion = 0xffffffffu;
diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc
index 02c76c2..ee79102 100644
--- a/src/tint/sem/type_test.cc
+++ b/src/tint/sem/type_test.cc
@@ -44,22 +44,38 @@
     const sem::Matrix* mat4x3_af = create<Matrix>(vec3_af, 4u);
     const sem::Reference* ref_u32 =
         create<Reference>(u32, ast::AddressSpace::kPrivate, ast::Access::kReadWrite);
-    const sem::Struct* str = create<Struct>(nullptr,
-                                            Sym("s"),
-                                            StructMemberList{
-                                                create<StructMember>(
-                                                    /* declaration */ nullptr,
-                                                    /* name */ Sym("x"),
-                                                    /* type */ f16,
-                                                    /* index */ 0u,
-                                                    /* offset */ 0u,
-                                                    /* align */ 4u,
-                                                    /* size */ 4u,
-                                                    /* location */ std::nullopt),
-                                            },
-                                            /* align*/ 4u,
-                                            /* size*/ 4u,
-                                            /* size_no_padding*/ 4u);
+    const sem::Struct* str_f16 = create<Struct>(nullptr,
+                                                Sym("str_f16"),
+                                                StructMemberList{
+                                                    create<StructMember>(
+                                                        /* declaration */ nullptr,
+                                                        /* name */ Sym("x"),
+                                                        /* type */ f16,
+                                                        /* index */ 0u,
+                                                        /* offset */ 0u,
+                                                        /* align */ 4u,
+                                                        /* size */ 4u,
+                                                        /* location */ std::nullopt),
+                                                },
+                                                /* align*/ 4u,
+                                                /* size*/ 4u,
+                                                /* size_no_padding*/ 4u);
+    const sem::Struct* str_af = create<Struct>(nullptr,
+                                               Sym("str_af"),
+                                               StructMemberList{
+                                                   create<StructMember>(
+                                                       /* declaration */ nullptr,
+                                                       /* name */ Sym("x"),
+                                                       /* type */ af,
+                                                       /* index */ 0u,
+                                                       /* offset */ 0u,
+                                                       /* align */ 4u,
+                                                       /* size */ 4u,
+                                                       /* location */ std::nullopt),
+                                               },
+                                               /* align*/ 4u,
+                                               /* size*/ 4u,
+                                               /* size_no_padding*/ 4u);
     const sem::Array* arr_i32 = create<Array>(
         /* element */ i32,
         /* count */ ConstantArrayCount{5u},
@@ -109,8 +125,15 @@
         /* size */ 5u * 64u,
         /* stride */ 5u * 64u,
         /* implicit_stride */ 5u * 64u);
-    const sem::Array* arr_str = create<Array>(
-        /* element */ str,
+    const sem::Array* arr_str_f16 = create<Array>(
+        /* element */ str_f16,
+        /* count */ ConstantArrayCount{5u},
+        /* align */ 4u,
+        /* size */ 5u * 4u,
+        /* stride */ 5u * 4u,
+        /* implicit_stride */ 5u * 4u);
+    const sem::Array* arr_str_af = create<Array>(
+        /* element */ str_af,
         /* count */ ConstantArrayCount{5u},
         /* align */ 4u,
         /* size */ 5u * 4u,
@@ -192,12 +215,12 @@
     EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32);
     EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32);
     EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16);
-    EXPECT_TYPE(Type::ElementOf(str), str);
+    EXPECT_TYPE(Type::ElementOf(str_f16), str_f16);
     EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
     EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
     EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16);
     EXPECT_TYPE(Type::ElementOf(arr_mat4x3_af), mat4x3_af);
-    EXPECT_TYPE(Type::ElementOf(arr_str), str);
+    EXPECT_TYPE(Type::ElementOf(arr_str_f16), str_f16);
 
     // With count
     uint32_t count = 42;
@@ -237,7 +260,7 @@
     EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16);
     EXPECT_EQ(count, 4u);
     count = 42;
-    EXPECT_TYPE(Type::ElementOf(str, &count), str);
+    EXPECT_TYPE(Type::ElementOf(str_f16, &count), str_f16);
     EXPECT_EQ(count, 1u);
     count = 42;
     EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32);
@@ -252,7 +275,7 @@
     EXPECT_TYPE(Type::ElementOf(arr_mat4x3_af, &count), mat4x3_af);
     EXPECT_EQ(count, 5u);
     count = 42;
-    EXPECT_TYPE(Type::ElementOf(arr_str, &count), str);
+    EXPECT_TYPE(Type::ElementOf(arr_str_f16, &count), str_f16);
     EXPECT_EQ(count, 5u);
 }
 
@@ -270,12 +293,12 @@
     EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32);
     EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32);
     EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16);
-    EXPECT_TYPE(Type::DeepestElementOf(str), str);
+    EXPECT_TYPE(Type::DeepestElementOf(str_f16), str_f16);
     EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32);
     EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32);
     EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16);
     EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af), af);
-    EXPECT_TYPE(Type::DeepestElementOf(arr_str), str);
+    EXPECT_TYPE(Type::DeepestElementOf(arr_str_f16), str_f16);
 
     // With count
     uint32_t count = 42;
@@ -315,7 +338,7 @@
     EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16);
     EXPECT_EQ(count, 12u);
     count = 42;
-    EXPECT_TYPE(Type::DeepestElementOf(str, &count), str);
+    EXPECT_TYPE(Type::DeepestElementOf(str_f16, &count), str_f16);
     EXPECT_EQ(count, 1u);
     count = 42;
     EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32);
@@ -330,7 +353,7 @@
     EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af, &count), af);
     EXPECT_EQ(count, 60u);
     count = 42;
-    EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), str);
+    EXPECT_TYPE(Type::DeepestElementOf(arr_str_f16, &count), str_f16);
     EXPECT_EQ(count, 5u);
 }
 
@@ -525,5 +548,39 @@
                 arr_mat4x3_f16);
 }
 
+TEST_F(TypeTest, HoldsAbstract) {
+    EXPECT_TRUE(af->HoldsAbstract());
+    EXPECT_TRUE(ai->HoldsAbstract());
+    EXPECT_TRUE(vec3_af->HoldsAbstract());
+    EXPECT_TRUE(vec3_ai->HoldsAbstract());
+    EXPECT_TRUE(mat4x3_af->HoldsAbstract());
+    EXPECT_TRUE(str_af->HoldsAbstract());
+    EXPECT_TRUE(arr_ai->HoldsAbstract());
+    EXPECT_TRUE(arr_vec3_ai->HoldsAbstract());
+    EXPECT_TRUE(arr_mat4x3_af->HoldsAbstract());
+    EXPECT_TRUE(arr_str_af->HoldsAbstract());
+    EXPECT_FALSE(f32->HoldsAbstract());
+    EXPECT_FALSE(f16->HoldsAbstract());
+    EXPECT_FALSE(i32->HoldsAbstract());
+    EXPECT_FALSE(u32->HoldsAbstract());
+    EXPECT_FALSE(vec2_f32->HoldsAbstract());
+    EXPECT_FALSE(vec3_f32->HoldsAbstract());
+    EXPECT_FALSE(vec3_f16->HoldsAbstract());
+    EXPECT_FALSE(vec4_f32->HoldsAbstract());
+    EXPECT_FALSE(vec3_u32->HoldsAbstract());
+    EXPECT_FALSE(vec3_i32->HoldsAbstract());
+    EXPECT_FALSE(mat2x4_f32->HoldsAbstract());
+    EXPECT_FALSE(mat3x4_f32->HoldsAbstract());
+    EXPECT_FALSE(mat4x2_f32->HoldsAbstract());
+    EXPECT_FALSE(mat4x3_f32->HoldsAbstract());
+    EXPECT_FALSE(mat4x3_f16->HoldsAbstract());
+    EXPECT_FALSE(str_f16->HoldsAbstract());
+    EXPECT_FALSE(arr_i32->HoldsAbstract());
+    EXPECT_FALSE(arr_vec3_i32->HoldsAbstract());
+    EXPECT_FALSE(arr_mat4x3_f16->HoldsAbstract());
+    EXPECT_FALSE(arr_mat4x3_f32->HoldsAbstract());
+    EXPECT_FALSE(arr_str_f16->HoldsAbstract());
+}
+
 }  // namespace
 }  // namespace tint::sem