Test that entry point IO attributes are of valid types

BUG=tint:773

Change-Id: I94e8624647c645efe7ed558caa3d3bd05dd72f63
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50260
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
Auto-Submit: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc
index f872844..3a7eef3 100644
--- a/src/resolver/entry_point_validation_test.cc
+++ b/src/resolver/entry_point_validation_test.cc
@@ -23,13 +23,14 @@
 #include "gmock/gmock.h"
 
 namespace tint {
+namespace resolver {
 namespace {
 
-class ResolverEntryPointValidationTest : public resolver::TestHelper,
+class ResolverEntryPointValidationTest : public TestHelper,
                                          public testing::Test {};
 
 TEST_F(ResolverEntryPointValidationTest, ReturnTypeAttribute_Location) {
-  // [[stage(vertex)]]
+  // [[stage(fragment)]]
   // fn main() -> [[location(0)]] f32 { return 1.0; }
   Func(Source{{12, 34}}, "main", {}, ty.f32(), {Return(1.0f)},
        {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
@@ -514,5 +515,118 @@
             "in its return type");
 }
 
+namespace TypeValidationTests {
+struct Params {
+  create_ast_type_func_ptr create_ast_type;
+  bool is_valid;
+};
+
+using TypeValidationTest = resolver::ResolverTestWithParam<Params>;
+
+static constexpr Params cases[] = {
+    {ast_f32, true},
+    {ast_i32, true},
+    {ast_u32, true},
+    {ast_bool, false},
+    {ast_vec2<ast_f32>, true},
+    {ast_vec3<ast_f32>, true},
+    {ast_vec4<ast_f32>, true},
+    {ast_mat2x2<ast_f32>, false},
+    {ast_mat2x2<ast_i32>, false},
+    {ast_mat2x2<ast_u32>, false},
+    {ast_mat2x2<ast_bool>, false},
+    {ast_mat3x3<ast_f32>, false},
+    {ast_mat3x3<ast_i32>, false},
+    {ast_mat3x3<ast_u32>, false},
+    {ast_mat3x3<ast_bool>, false},
+    {ast_mat4x4<ast_f32>, false},
+    {ast_mat4x4<ast_i32>, false},
+    {ast_mat4x4<ast_u32>, false},
+    {ast_mat4x4<ast_bool>, false},
+    {ast_alias<ast_f32>, true},
+    {ast_alias<ast_i32>, true},
+    {ast_alias<ast_u32>, true},
+    {ast_alias<ast_bool>, false},
+};
+
+TEST_P(TypeValidationTest, BareInputs) {
+  // [[stage(fragment)]]
+  // fn main([[location(0)]] a : *) {}
+  auto params = GetParam();
+  auto* a = Param("a", params.create_ast_type(ty), {Location(0)});
+  Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kFragment)});
+
+  if (params.is_valid) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    EXPECT_FALSE(r()->Resolve());
+  }
+}
+
+TEST_P(TypeValidationTest, StructInputs) {
+  // struct Input {
+  //   [[location(0)]] a : *;
+  // };
+  // [[stage(fragment)]]
+  // fn main(a : Input) {}
+  auto params = GetParam();
+  auto* input = Structure(
+      "Input", {Member("a", params.create_ast_type(ty), {Location(0)})});
+  auto* a = Param("a", input, {});
+  Func(Source{{12, 34}}, "main", {a}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kFragment)});
+
+  if (params.is_valid) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    EXPECT_FALSE(r()->Resolve());
+  }
+}
+
+TEST_P(TypeValidationTest, BareOutputs) {
+  // [[stage(fragment)]]
+  // fn main() -> [[location(0)]] * {
+  //   return *();
+  // }
+  auto params = GetParam();
+  Func(Source{{12, 34}}, "main", {}, params.create_ast_type(ty),
+       {Return(Construct(params.create_ast_type(ty)))},
+       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+
+  if (params.is_valid) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    EXPECT_FALSE(r()->Resolve());
+  }
+}
+
+TEST_P(TypeValidationTest, StructOutputs) {
+  // struct Output {
+  //   [[location(0)]] a : *;
+  // };
+  // [[stage(fragment)]]
+  // fn main() -> Output {
+  //   return Output();
+  // }
+  auto params = GetParam();
+  auto* output = Structure(
+      "Output", {Member("a", params.create_ast_type(ty), {Location(0)})});
+  Func(Source{{12, 34}}, "main", {}, output, {Return(Construct(output))},
+       {Stage(ast::PipelineStage::kFragment)});
+
+  if (params.is_valid) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    EXPECT_FALSE(r()->Resolve());
+  }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverEntryPointValidationTest,
+                         TypeValidationTest,
+                         testing::ValuesIn(cases));
+
+}  // namespace TypeValidationTests
+
 }  // namespace
+}  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index a155ff4..32273e2 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -832,6 +832,19 @@
             diagnostics_.add_error(err, source);
             return false;
           }
+
+          // Check that all user defined attributes are numeric scalars, vectors
+          // of numeric scalars.
+          // Testing for being a struct is handled by the if portion above.
+          if (!pipeline_io_attribute->Is<ast::BuiltinDecoration>()) {
+            if (!Canonical(ty)->is_numeric_scalar_or_vector()) {
+              diagnostics_.add_error(
+                  "User defined entry point IO types must be a numeric scalar, "
+                  "a numeric vector, or a structure",
+                  source);
+              return false;
+            }
+          }
         }
 
         return true;
diff --git a/src/sem/type.cc b/src/sem/type.cc
index 947b149..589688f 100644
--- a/src/sem/type.cc
+++ b/src/sem/type.cc
@@ -132,6 +132,15 @@
   return Is<Bool>() || is_bool_vector();
 }
 
+bool Type::is_numeric_vector() const {
+  return Is<Vector>(
+      [](const Vector* v) { return v->type()->is_numeric_scalar(); });
+}
+
+bool Type::is_numeric_scalar_or_vector() const {
+  return is_numeric_scalar() || is_numeric_vector();
+}
+
 bool Type::is_handle() const {
   return IsAnyOf<Sampler, Texture>();
 }
diff --git a/src/sem/type.h b/src/sem/type.h
index 49da00c..ca7bdd4 100644
--- a/src/sem/type.h
+++ b/src/sem/type.h
@@ -125,6 +125,10 @@
   bool is_bool_vector() const;
   /// @returns true if this type is boolean scalar or vector
   bool is_bool_scalar_or_vector() const;
+  /// @returns true if this type is a numeric vector
+  bool is_numeric_vector() const;
+  /// @returns true if this type is a numeric scale or vector
+  bool is_numeric_scalar_or_vector() const;
   /// @returns true if this type is a handle type
   bool is_handle() const;
 
diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc
index 2f4b4ab..876c21d 100644
--- a/src/transform/canonicalize_entry_point_io_test.cc
+++ b/src/transform/canonicalize_entry_point_io_test.cc
@@ -508,13 +508,13 @@
 struct VertexOutput {
   [[location(1)]] b : u32;
   [[builtin(position)]] pos : vec4<f32>;
-  [[location(3)]] d : bool;
+  [[location(3)]] d : u32;
   [[location(0)]] a : f32;
   [[location(2)]] c : i32;
 };
 
 struct FragmentInputExtra {
-  [[location(3)]] d : bool;
+  [[location(3)]] d : u32;
   [[builtin(position)]] pos : vec4<f32>;
   [[location(0)]] a : f32;
 };
@@ -536,13 +536,13 @@
 struct VertexOutput {
   b : u32;
   pos : vec4<f32>;
-  d : bool;
+  d : u32;
   a : f32;
   c : i32;
 };
 
 struct FragmentInputExtra {
-  d : bool;
+  d : u32;
   pos : vec4<f32>;
   a : f32;
 };
@@ -555,7 +555,7 @@
   [[location(2)]]
   c : i32;
   [[location(3)]]
-  d : bool;
+  d : u32;
   [[builtin(position)]]
   pos : vec4<f32>;
 };
@@ -574,7 +574,7 @@
   [[location(2)]]
   c : i32;
   [[location(3)]]
-  d : bool;
+  d : u32;
   [[builtin(position)]]
   pos : vec4<f32>;
   [[builtin(front_facing)]]