resolver: Validate vector types

Fixed: tint:953
Change-Id: I3742680e49894a93db41219e512796ba9bdf036a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56778
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 307f807..8d9f265 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -324,8 +324,12 @@
     }
     if (auto* t = ty->As<ast::Vector>()) {
       if (auto* el = Type(t->type())) {
-        return builder_->create<sem::Vector>(const_cast<sem::Type*>(el),
-                                             t->size());
+        if (auto* vector = builder_->create<sem::Vector>(
+                const_cast<sem::Type*>(el), t->size())) {
+          if (ValidateVector(vector, t->source())) {
+            return vector;
+          }
+        }
       }
       return nullptr;
     }
@@ -333,10 +337,10 @@
       if (auto* el = Type(t->type())) {
         if (auto* column_type = builder_->create<sem::Vector>(
                 const_cast<sem::Type*>(el), t->rows())) {
-          if (auto* matrix_type =
+          if (auto* matrix =
                   builder_->create<sem::Matrix>(column_type, t->columns())) {
-            if (ValidateMatrix(matrix_type, t->source())) {
-              return matrix_type;
+            if (ValidateMatrix(matrix, t->source())) {
+              return matrix;
             }
           }
         }
@@ -2300,14 +2304,22 @@
   return true;
 }
 
-bool Resolver::ValidateMatrix(const sem::Matrix* matrix_type,
-                              const Source& source) {
-  if (!matrix_type->is_float_matrix()) {
+bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
+  if (!ty->type()->is_scalar()) {
+    AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
+             source);
+    return false;
+  }
+  return true;
+}
+
+bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
+  if (!ty->is_float_matrix()) {
     AddError("matrix element type must be 'f32'", source);
     return false;
   }
   return true;
-}  // namespace resolver
+}
 
 bool Resolver::ValidateMatrixConstructor(
     const ast::TypeConstructorExpression* ctor,
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index cf26a39..047ccc1 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -280,7 +280,7 @@
   bool ValidateGlobalVariable(const VariableInfo* var);
   bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
                                      const sem::Type* storage_type);
-  bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source);
+  bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
   bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
                                  const sem::Matrix* matrix_type);
   bool ValidateFunctionParameter(const ast::Function* func,
@@ -300,6 +300,7 @@
                                    const std::string& type_name,
                                    const sem::Type* rhs_type,
                                    const std::string& rhs_type_name);
+  bool ValidateVector(const sem::Vector* ty, const Source& source);
   bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
                                  const sem::Vector* vec_type);
   bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 70db243..d231774 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -1356,29 +1356,19 @@
   //   vec4<f32> foo
   // }
   // struct A {
-  //   vec3<struct b> mem
+  //   array<b, 3> mem
   // }
   // var c : A
   // c.mem[0].foo.yx
   //   -> vec2<f32>
   //
-  // MemberAccessor{
-  //   MemberAccessor{
-  //     ArrayAccessor{
-  //       MemberAccessor{
-  //         Identifier{c}
-  //         Identifier{mem}
-  //       }
-  //       ScalarConstructor{0}
-  //     }
-  //     Identifier{foo}
-  //   }
-  //   Identifier{yx}
+  // fn f() {
+  //   c.mem[0].foo
   // }
   //
 
   auto* stB = Structure("B", {Member("foo", ty.vec4<f32>())});
-  auto* stA = Structure("A", {Member("mem", ty.vec(ty.Of(stB), 3))});
+  auto* stA = Structure("A", {Member("mem", ty.array(ty.Of(stB), 3))});
   Global("c", ty.Of(stA), ast::StorageClass::kPrivate);
 
   auto* mem = MemberAccessor(
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index 68c307d..a0a17c2 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -156,6 +156,9 @@
 template <typename T>
 using mat4x4 = mat<4, 4, T>;
 
+template <int N, typename T>
+struct array {};
+
 template <typename TO, int ID = 0>
 struct alias {};
 
@@ -384,6 +387,43 @@
   }
 };
 
+/// Helper for building array types and expressions
+template <int N, typename T>
+struct DataType<array<N, T>> {
+  /// true as arrays are a composite type
+  static constexpr bool is_composite = true;
+
+  /// @param b the ProgramBuilder
+  /// @return a new AST array type
+  static inline ast::Type* AST(ProgramBuilder& b) {
+    return b.ty.array(DataType<T>::AST(b), N);
+  }
+  /// @param b the ProgramBuilder
+  /// @return the semantic array type
+  static inline sem::Type* Sem(ProgramBuilder& b) {
+    return b.create<sem::Array>(DataType<T>::Sem(b), N);
+  }
+  /// @param b the ProgramBuilder
+  /// @param elem_value the value each element in the array will be initialized
+  /// with
+  /// @return a new AST array value expression
+  static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
+    return b.Construct(AST(b), ExprArgs(b, elem_value));
+  }
+
+  /// @param b the ProgramBuilder
+  /// @param elem_value the value each element will be initialized with
+  /// @return the list of expressions that are used to construct the array
+  static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
+                                             int elem_value) {
+    ast::ExpressionList args;
+    for (int i = 0; i < N; i++) {
+      args.emplace_back(DataType<T>::Expr(b, elem_value));
+    }
+    return args;
+  }
+};
+
 /// Struct of all creation pointer types
 struct CreatePtrs {
   /// ast node type create function
diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc
index fd73e05..52925b4 100644
--- a/src/resolver/type_validation_test.cc
+++ b/src/resolver/type_validation_test.cc
@@ -42,6 +42,8 @@
 using mat3x3 = builder::mat3x3<T>;
 template <typename T>
 using mat4x4 = builder::mat4x4<T>;
+template <int N, typename T>
+using array = builder::array<N, T>;
 template <typename T>
 using alias = builder::alias<T>;
 template <typename T>
@@ -751,6 +753,126 @@
 
 }  // namespace StorageTextureTests
 
+namespace MatrixTests {
+struct Params {
+  uint32_t columns;
+  uint32_t rows;
+  builder::ast_type_func_ptr elem_ty;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t columns, uint32_t rows) {
+  return Params{columns, rows, DataType<T>::AST};
+}
+
+using ValidMatrixTypes = ResolverTestWithParam<Params>;
+TEST_P(ValidMatrixTypes, Okay) {
+  // var a : matNxM<EL_TY>;
+  auto& params = GetParam();
+  Global("a", ty.mat(params.elem_ty(*this), params.columns, params.rows),
+         ast::StorageClass::kPrivate);
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+                         ValidMatrixTypes,
+                         testing::Values(ParamsFor<f32>(2, 2),
+                                         ParamsFor<f32>(2, 3),
+                                         ParamsFor<f32>(2, 4),
+                                         ParamsFor<f32>(3, 2),
+                                         ParamsFor<f32>(3, 3),
+                                         ParamsFor<f32>(3, 4),
+                                         ParamsFor<f32>(4, 2),
+                                         ParamsFor<f32>(4, 3),
+                                         ParamsFor<f32>(4, 4),
+                                         ParamsFor<alias<f32>>(4, 2),
+                                         ParamsFor<alias<f32>>(4, 3),
+                                         ParamsFor<alias<f32>>(4, 4)));
+
+using InvalidMatrixElementTypes = ResolverTestWithParam<Params>;
+TEST_P(InvalidMatrixElementTypes, InvalidElementType) {
+  // var a : matNxM<EL_TY>;
+  auto& params = GetParam();
+  Global("a",
+         ty.mat(Source{{12, 34}}, params.elem_ty(*this), params.columns,
+                params.rows),
+         ast::StorageClass::kPrivate);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+                         InvalidMatrixElementTypes,
+                         testing::Values(ParamsFor<bool>(4, 2),
+                                         ParamsFor<i32>(4, 3),
+                                         ParamsFor<u32>(4, 4),
+                                         ParamsFor<vec2<f32>>(2, 2),
+                                         ParamsFor<vec3<i32>>(2, 3),
+                                         ParamsFor<vec4<u32>>(2, 4),
+                                         ParamsFor<mat2x2<f32>>(3, 2),
+                                         ParamsFor<mat3x3<f32>>(3, 3),
+                                         ParamsFor<mat4x4<f32>>(3, 4),
+                                         ParamsFor<array<2, f32>>(4, 2)));
+}  // namespace MatrixTests
+
+namespace VectorTests {
+struct Params {
+  uint32_t width;
+  builder::ast_type_func_ptr elem_ty;
+};
+
+template <typename T>
+constexpr Params ParamsFor(uint32_t width) {
+  return Params{width, DataType<T>::AST};
+}
+
+using ValidVectorTypes = ResolverTestWithParam<Params>;
+TEST_P(ValidVectorTypes, Okay) {
+  // var a : vecN<EL_TY>;
+  auto& params = GetParam();
+  Global("a", ty.vec(params.elem_ty(*this), params.width),
+         ast::StorageClass::kPrivate);
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+                         ValidVectorTypes,
+                         testing::Values(ParamsFor<bool>(2),
+                                         ParamsFor<f32>(2),
+                                         ParamsFor<i32>(2),
+                                         ParamsFor<u32>(2),
+                                         ParamsFor<bool>(3),
+                                         ParamsFor<f32>(3),
+                                         ParamsFor<i32>(3),
+                                         ParamsFor<u32>(3),
+                                         ParamsFor<bool>(4),
+                                         ParamsFor<f32>(4),
+                                         ParamsFor<i32>(4),
+                                         ParamsFor<u32>(4),
+                                         ParamsFor<alias<bool>>(4),
+                                         ParamsFor<alias<f32>>(4),
+                                         ParamsFor<alias<i32>>(4),
+                                         ParamsFor<alias<u32>>(4)));
+
+using InvalidVectorElementTypes = ResolverTestWithParam<Params>;
+TEST_P(InvalidVectorElementTypes, InvalidElementType) {
+  // var a : vecN<EL_TY>;
+  auto& params = GetParam();
+  Global("a", ty.vec(Source{{12, 34}}, params.elem_ty(*this), params.width),
+         ast::StorageClass::kPrivate);
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: vector element type must be 'bool', 'f32', 'i32' or 'u32'");
+}
+INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
+                         InvalidVectorElementTypes,
+                         testing::Values(ParamsFor<vec2<f32>>(2),
+                                         ParamsFor<vec3<i32>>(2),
+                                         ParamsFor<vec4<u32>>(2),
+                                         ParamsFor<mat2x2<f32>>(2),
+                                         ParamsFor<mat3x3<f32>>(2),
+                                         ParamsFor<mat4x4<f32>>(2),
+                                         ParamsFor<array<2, f32>>(2)));
+}  // namespace VectorTests
+
 }  // namespace
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index c688ad7..fee3f57 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -2007,17 +2007,6 @@
 
 using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
 
-TEST_F(MatrixConstructorTest, Expr_Constructor_Matrix_NotF32) {
-  // m2x2<i32>()
-  SetSource(Source::Location({12, 34}));
-  auto* tc = mat2x2<i32>(
-      create<ast::TypeConstructorExpression>(ty.mat2x2<i32>(), ExprList()));
-  WrapInFunction(tc);
-
-  EXPECT_FALSE(r()->Resolve());
-  EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
-}
-
 TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
   // matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments