[tint][ir][val] Check vector and matrix element types

Fixes: 380903152
Fixes: 380898781
Change-Id: I1b66805a3186496f55b9150c573ac2d970a4e0c4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/216615
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index d56cf0b..7677f4c 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -83,6 +83,7 @@
 #include "src/tint/lang/core/type/f32.h"
 #include "src/tint/lang/core/type/i32.h"
 #include "src/tint/lang/core/type/i8.h"
+#include "src/tint/lang/core/type/matrix.h"
 #include "src/tint/lang/core/type/memory_view.h"
 #include "src/tint/lang/core/type/pointer.h"
 #include "src/tint/lang/core/type/reference.h"
@@ -936,7 +937,8 @@
                                  size_t num_results,
                                  size_t num_operands);
 
-    /// Checks that @p type does not use any types that are prohibited by the target capabilities.
+    /// Checks that @p type is allowed by the spec, and does not use any types that are prohibited
+    /// by the target capabilities.
     /// @param type the type
     /// @param diag a function that creates an error diagnostic for the source of the type
     /// @param ignore_caps a set of capabilities to ignore for this check
@@ -1778,6 +1780,20 @@
                 }
                 return true;
             },
+            [&](const core::type::Vector* v) {
+                if (!v->Type()->IsScalar()) {
+                    diag() << "vector elements must be scalars";
+                    return false;
+                }
+                return true;
+            },
+            [&](const core::type::Matrix* m) {
+                if (!m->Type()->IsFloatScalar()) {
+                    diag() << "matrix elements must be float scalars";
+                    return false;
+                }
+                return true;
+            },
             [](Default) { return true; });
     };
 
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 44f1f5c..57f72f0 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -4907,6 +4907,106 @@
 )");
 }
 
+TEST_F(IR_ValidatorTest, Type_VectorElements) {
+    auto* f = b.Function("my_func", ty.void_());
+
+    b.Append(f->Block(), [&] {
+        b.Var("u32_valid", AddressSpace::kFunction, ty.vec4(ty.u32()));
+        b.Var("i32_valid", AddressSpace::kFunction, ty.vec4(ty.i32()));
+        b.Var("bool_valid", AddressSpace::kFunction, ty.vec2(ty.bool_()));
+        b.Var("f16_valid", AddressSpace::kFunction, ty.vec3(ty.f16()));
+        b.Var("f32_valid", AddressSpace::kFunction, ty.vec3(ty.f32()));
+        b.Var("void_invalid", AddressSpace::kFunction, ty.vec2(ty.void_()));
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(), R"(:8:5 error: var: vector elements must be scalars
+    %void_invalid:ptr<function, vec2<void>, read_write> = var
+    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    %u32_valid:ptr<function, vec4<u32>, read_write> = var
+    %i32_valid:ptr<function, vec4<i32>, read_write> = var
+    %bool_valid:ptr<function, vec2<bool>, read_write> = var
+    %f16_valid:ptr<function, vec3<f16>, read_write> = var
+    %f32_valid:ptr<function, vec3<f32>, read_write> = var
+    %void_invalid:ptr<function, vec2<void>, read_write> = var
+    ret
+  }
+}
+)");
+}
+
+TEST_F(IR_ValidatorTest, Type_MatrixElements) {
+    auto* f = b.Function("my_func", ty.void_());
+
+    b.Append(f->Block(), [&] {
+        b.Var("u32_invalid", AddressSpace::kFunction, ty.mat2x2(ty.u32()));
+        b.Var("i32_invalid", AddressSpace::kFunction, ty.mat3x2(ty.i32()));
+        b.Var("bool_invalid", AddressSpace::kFunction, ty.mat4x2(ty.bool_()));
+        b.Var("f16_valid", AddressSpace::kFunction, ty.mat2x3(ty.f16()));
+        b.Var("f32_valid", AddressSpace::kFunction, ty.mat4x4(ty.f32()));
+        b.Var("void_invalid", AddressSpace::kFunction, ty.mat3x3(ty.void_()));
+        b.Return(f);
+    });
+
+    auto res = ir::Validate(mod);
+    ASSERT_NE(res, Success);
+    EXPECT_EQ(res.Failure().reason.Str(), R"(:3:5 error: var: matrix elements must be float scalars
+    %u32_invalid:ptr<function, mat2x2<u32>, read_write> = var
+    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+:4:5 error: var: matrix elements must be float scalars
+    %i32_invalid:ptr<function, mat3x2<i32>, read_write> = var
+    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+:5:5 error: var: matrix elements must be float scalars
+    %bool_invalid:ptr<function, mat4x2<bool>, read_write> = var
+    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+:8:5 error: var: matrix elements must be float scalars
+    %void_invalid:ptr<function, mat3x3<void>, read_write> = var
+    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:2:3 note: in block
+  $B1: {
+  ^^^
+
+note: # Disassembly
+%my_func = func():void {
+  $B1: {
+    %u32_invalid:ptr<function, mat2x2<u32>, read_write> = var
+    %i32_invalid:ptr<function, mat3x2<i32>, read_write> = var
+    %bool_invalid:ptr<function, mat4x2<bool>, read_write> = var
+    %f16_valid:ptr<function, mat2x3<f16>, read_write> = var
+    %f32_valid:ptr<function, mat4x4<f32>, read_write> = var
+    %void_invalid:ptr<function, mat3x3<void>, read_write> = var
+    ret
+  }
+}
+)");
+}
+
 TEST_F(IR_ValidatorTest, Var_RootBlock_NullResult) {
     auto* v = mod.CreateInstruction<ir::Var>(nullptr);
     v->SetInitializer(b.Constant(0_i));
diff --git a/src/tint/lang/core/type/type.cc b/src/tint/lang/core/type/type.cc
index cf77682..10ef6a5 100644
--- a/src/tint/lang/core/type/type.cc
+++ b/src/tint/lang/core/type/type.cc
@@ -34,6 +34,7 @@
 #include "src/tint/lang/core/type/f16.h"
 #include "src/tint/lang/core/type/f32.h"
 #include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/i8.h"
 #include "src/tint/lang/core/type/matrix.h"
 #include "src/tint/lang/core/type/pointer.h"
 #include "src/tint/lang/core/type/reference.h"
@@ -41,6 +42,7 @@
 #include "src/tint/lang/core/type/struct.h"
 #include "src/tint/lang/core/type/texture.h"
 #include "src/tint/lang/core/type/u32.h"
+#include "src/tint/lang/core/type/u8.h"
 #include "src/tint/lang/core/type/vector.h"
 #include "src/tint/utils/rtti/switch.h"
 
@@ -109,7 +111,7 @@
 }
 
 bool Type::IsIntegerScalar() const {
-    return IsAnyOf<U32, I32>();
+    return IsAnyOf<U32, I32, U8, I8>();
 }
 
 bool Type::IsIntegerVector() const {