validation: validate builtin pipeline stage and Input/Output

Bug: tint:957
Change-Id: I5f509e61501b39f2a0b3bc10a204ae1f39a0d460
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57105
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 3dc70ab..3c9b935 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -1059,11 +1059,11 @@
 
 TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) {
   auto* in_var0 =
-      Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kInstanceIndex)});
-  auto* in_var1 = Param("in_var1", ty.u32(), {Location(0u)});
-  Func("foo", {in_var0, in_var1}, ty.u32(), {Return("in_var1")},
+      Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)});
+  auto* in_var1 = Param("in_var1", ty.f32(), {Location(0u)});
+  Func("foo", {in_var0, in_var1}, ty.f32(), {Return("in_var1")},
        {Stage(ast::PipelineStage::kFragment)},
-       {Builtin(ast::Builtin::kSampleMask)});
+       {Builtin(ast::Builtin::kFragDepth)});
   Inspector& inspector = Build();
 
   auto result = inspector.GetEntryPoints();
@@ -1075,7 +1075,7 @@
   EXPECT_EQ("in_var1", result[0].input_variables[0].name);
   EXPECT_TRUE(result[0].input_variables[0].has_location_decoration);
   EXPECT_EQ(0u, result[0].input_variables[0].location_decoration);
-  EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type);
+  EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[0].component_type);
 
   ASSERT_EQ(0u, result[0].output_variables.size());
 }
diff --git a/src/resolver/builtins_validation_test.cc b/src/resolver/builtins_validation_test.cc
index 6c06ce0..46a2b66 100644
--- a/src/resolver/builtins_validation_test.cc
+++ b/src/resolver/builtins_validation_test.cc
@@ -16,9 +16,209 @@
 #include "src/resolver/resolver_test_helper.h"
 
 namespace tint {
+namespace resolver {
 namespace {
+
+template <typename T>
+using DataType = builder::DataType<T>;
+template <typename T>
+using vec2 = builder::vec2<T>;
+template <typename T>
+using vec3 = builder::vec3<T>;
+template <typename T>
+using vec4 = builder::vec4<T>;
+template <typename T>
+using f32 = builder::f32;
+using i32 = builder::i32;
+using u32 = builder::u32;
+
 class ResolverBuiltinsValidationTest : public resolver::TestHelper,
                                        public testing::Test {};
+namespace TypeTemp {
+struct Params {
+  builder::ast_type_func_ptr type;
+  ast::Builtin builtin;
+  ast::PipelineStage stage;
+  bool is_valid;
+};
+
+template <typename T>
+constexpr Params ParamsFor(ast::Builtin builtin,
+                           ast::PipelineStage stage,
+                           bool is_valid) {
+  return Params{DataType<T>::AST, builtin, stage, is_valid};
+}
+static constexpr Params cases[] = {
+    ParamsFor<u32>(ast::Builtin::kVertexIndex,
+                   ast::PipelineStage::kVertex,
+                   true),
+    ParamsFor<u32>(ast::Builtin::kVertexIndex,
+                   ast::PipelineStage::kFragment,
+                   false),
+    ParamsFor<u32>(ast::Builtin::kVertexIndex,
+                   ast::PipelineStage::kCompute,
+                   false),
+
+    ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+                   ast::PipelineStage::kVertex,
+                   true),
+    ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+                   ast::PipelineStage::kFragment,
+                   false),
+    ParamsFor<u32>(ast::Builtin::kInstanceIndex,
+                   ast::PipelineStage::kCompute,
+                   false),
+
+    ParamsFor<bool>(ast::Builtin::kFrontFacing,
+                    ast::PipelineStage::kVertex,
+                    false),
+    ParamsFor<bool>(ast::Builtin::kFrontFacing,
+                    ast::PipelineStage::kFragment,
+                    true),
+    ParamsFor<bool>(ast::Builtin::kFrontFacing,
+                    ast::PipelineStage::kCompute,
+                    false),
+
+    ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+                         ast::PipelineStage::kVertex,
+                         false),
+    ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+                         ast::PipelineStage::kFragment,
+                         false),
+    ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
+                         ast::PipelineStage::kCompute,
+                         true),
+
+    ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+                   ast::PipelineStage::kVertex,
+                   false),
+    ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+                   ast::PipelineStage::kFragment,
+                   false),
+    ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
+                   ast::PipelineStage::kCompute,
+                   true),
+
+    ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+                         ast::PipelineStage::kVertex,
+                         false),
+    ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+                         ast::PipelineStage::kFragment,
+                         false),
+    ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
+                         ast::PipelineStage::kCompute,
+                         true),
+
+    ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+                         ast::PipelineStage::kVertex,
+                         false),
+    ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+                         ast::PipelineStage::kFragment,
+                         false),
+    ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
+                         ast::PipelineStage::kCompute,
+                         true),
+
+    ParamsFor<u32>(ast::Builtin::kSampleIndex,
+                   ast::PipelineStage::kVertex,
+                   false),
+    ParamsFor<u32>(ast::Builtin::kSampleIndex,
+                   ast::PipelineStage::kFragment,
+                   true),
+    ParamsFor<u32>(ast::Builtin::kSampleIndex,
+                   ast::PipelineStage::kCompute,
+                   false),
+
+    ParamsFor<u32>(ast::Builtin::kSampleMask,
+                   ast::PipelineStage::kVertex,
+                   false),
+    ParamsFor<u32>(ast::Builtin::kSampleMask,
+                   ast::PipelineStage::kFragment,
+                   true),
+    ParamsFor<u32>(ast::Builtin::kSampleMask,
+                   ast::PipelineStage::kCompute,
+                   false),
+};
+
+using ResolverBuiltinsStageTest = ResolverTestWithParam<Params>;
+TEST_P(ResolverBuiltinsStageTest, All_input) {
+  const Params& params = GetParam();
+
+  auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
+  auto* input =
+      Param("input", params.type(*this),
+            ast::DecorationList{Builtin(Source{{12, 34}}, params.builtin)});
+  switch (params.stage) {
+    case ast::PipelineStage::kVertex:
+      Func("main", {input}, ty.vec4<f32>(), {Return(p)},
+           {Stage(ast::PipelineStage::kVertex)},
+           {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
+      break;
+    case ast::PipelineStage::kFragment:
+      Func("main", {input}, ty.void_(), {},
+           {Stage(ast::PipelineStage::kFragment)}, {});
+      break;
+    case ast::PipelineStage::kCompute:
+      Func("main", {input}, ty.void_(), {},
+           ast::DecorationList{Stage(ast::PipelineStage::kCompute),
+                               WorkgroupSize(1)});
+      break;
+    default:
+      break;
+  }
+
+  if (params.is_valid) {
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+  } else {
+    std::stringstream err;
+    err << "12:34 error: builtin(" << params.builtin << ")";
+    err << " cannot be used in input of " << params.stage << " pipeline stage";
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(), err.str());
+  }
+}
+INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
+                         ResolverBuiltinsStageTest,
+                         testing::ValuesIn(cases));
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
+  // [[stage(fragment)]]
+  // fn fs_main(
+  //   [[builtin(kFragDepth)]] fd: f32,
+  // ) -> [[location(0)]] f32 { return 1.0; }
+  auto* fd = Param(
+      "fd", ty.f32(),
+      ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
+  Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
+       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
+       {Location(0)});
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: builtin(frag_depth) cannot be used in input of "
+            "fragment pipeline stage");
+}
+
+TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) {
+  // Struct MyInputs {
+  //   [[builtin(front_facing)]] ff: bool;
+  // };
+  // [[stage(fragment)]]
+  // fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
+
+  auto* s = Structure(
+      "MyInputs", {Member("frag_depth", ty.f32(),
+                          ast::DecorationList{Builtin(
+                              Source{{12, 34}}, ast::Builtin::kFragDepth)})});
+
+  Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
+       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: builtin(frag_depth) cannot be used in input of fragment "
+      "pipeline stage\nnote: while analysing entry point fragShader");
+}
+}  // namespace TypeTemp
 
 TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
   // struct MyInputs {
@@ -170,15 +370,12 @@
 
 TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) {
   // [[stage(fragment)]]
-  // fn fs_main(
-  //   [[builtin(kFragDepth)]] fd: f32,
-  // ) -> [[location(0)]] f32 { return 1.0; }
-  auto* fd = Param(
-      "fd", ty.i32(),
+  // fn fs_main() -> [[builtin(kFragDepth)]] f32 { var fd: i32; return fd; }
+  auto* fd = Var("fd", ty.i32());
+  Func(
+      "fs_main", {}, ty.i32(), {Decl(fd), Return(fd)},
+      ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
-  Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)},
-       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
-       {Location(0)});
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
             "12:34 error: store type of builtin(frag_depth) must be 'f32'");
@@ -227,44 +424,43 @@
   // fn fs_main(
   //   [[builtin(kPosition)]] p: vec4<f32>,
   //   [[builtin(front_facing)]] ff: bool,
-  //   [[builtin(frag_depth)]] fd: f32,
   //   [[builtin(sample_index)]] si: u32,
   //   [[builtin(sample_mask)]] sm : u32
-  // ) -> [[location(0)]] f32 { return 1.0; }
+  // ) -> [[builtin(frag_depth)]] f32 { var fd: f32; return fd; }
   auto* p = Param("p", ty.vec4<f32>(),
                   ast::DecorationList{Builtin(ast::Builtin::kPosition)});
   auto* ff = Param("ff", ty.bool_(),
                    ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)});
-  auto* fd = Param("fd", ty.f32(),
-                   ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
   auto* si = Param("si", ty.u32(),
                    ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)});
   auto* sm = Param("sm", ty.u32(),
                    ast::DecorationList{Builtin(ast::Builtin::kSampleMask)});
-  Func(
-      "fs_main", ast::VariableList{p, ff, fd, si, sm}, ty.f32(), {Return(1.0f)},
-      ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, {Location(0)});
+  auto* var_fd = Var("fd", ty.f32());
+  Func("fs_main", ast::VariableList{p, ff, si, sm}, ty.f32(),
+       {Decl(var_fd), Return(var_fd)},
+       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
+       ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
 TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) {
   // [[stage(vertex)]]
   // fn main(
-  //   [[builtin(kVertexIndex)]] vi : u32,
-  //   [[builtin(kInstanceIndex)]] ii : u32,
-  //   [[builtin(kPosition)]] p :vec4<f32>
-  // ) {}
+  //   [[builtin(vertex_index)]] vi : u32,
+  //   [[builtin(instance_index)]] ii : u32,
+  // ) -> [[builtin(position)]] vec4<f32> { var p :vec4<f32>; return p; }
   auto* vi = Param("vi", ty.u32(),
                    ast::DecorationList{
                        Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
-  auto* p = Param("p", ty.vec4<f32>(),
-                  ast::DecorationList{Builtin(ast::Builtin::kPosition)});
+
   auto* ii = Param("ii", ty.u32(),
                    ast::DecorationList{Builtin(Source{{12, 34}},
                                                ast::Builtin::kInstanceIndex)});
-  Func("main", ast::VariableList{vi, ii, p}, ty.vec4<f32>(),
+  auto* p = Var("p", ty.vec4<f32>());
+  Func("main", ast::VariableList{vi, ii}, ty.vec4<f32>(),
        {
-           Return(Expr(p)),
+           Decl(p),
+           Return(p),
        },
        ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
        ast::DecorationList{Builtin(ast::Builtin::kPosition)});
@@ -369,7 +565,6 @@
 TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
   // Struct MyInputs {
   //   [[builtin(kPosition)]] p: vec4<f32>;
-  //   [[builtin(front_facing)]] ff: bool;
   //   [[builtin(frag_depth)]] fd: f32;
   //   [[builtin(sample_index)]] si: u32;
   //   [[builtin(sample_mask)]] sm : u32;;
@@ -383,8 +578,6 @@
               ast::DecorationList{Builtin(ast::Builtin::kPosition)}),
        Member("front_facing", ty.bool_(),
               ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}),
-       Member("frag_depth", ty.f32(),
-              ast::DecorationList{Builtin(ast::Builtin::kFragDepth)}),
        Member("sample_index", ty.u32(),
               ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}),
        Member("sample_mask", ty.u32(),
@@ -1006,4 +1199,5 @@
                                            "pack2x16float"));
 
 }  // namespace
+}  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index ba261a8..61f8d36 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -145,7 +145,8 @@
   } else {
     EXPECT_FALSE(r()->Resolve()) << r()->error();
     EXPECT_EQ(r()->error(),
-              "error: decoration is not valid for function parameters");
+              "error: decoration is not valid for non-entry point function "
+              "parameters");
   }
 }
 INSTANTIATE_TEST_SUITE_P(
@@ -244,7 +245,8 @@
   } else {
     EXPECT_FALSE(r()->Resolve()) << r()->error();
     EXPECT_EQ(r()->error(),
-              "error: decoration is not valid for function return types");
+              "error: decoration is not valid for non-entry point function "
+              "return types");
   }
 }
 INSTANTIATE_TEST_SUITE_P(
diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc
index 6b63223..bd0201f 100644
--- a/src/resolver/entry_point_validation_test.cc
+++ b/src/resolver/entry_point_validation_test.cc
@@ -289,16 +289,6 @@
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
-TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Builtin) {
-  // [[stage(fragment)]]
-  // fn main([[builtin(frag_depth)]] param : f32) {}
-  auto* param = Param("param", ty.f32(), {Builtin(ast::Builtin::kFragDepth)});
-  Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
-       {Stage(ast::PipelineStage::kFragment)});
-
-  EXPECT_TRUE(r()->Resolve()) << r()->error();
-}
-
 TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) {
   // [[stage(fragment)]]
   // fn main(param : f32) {}
@@ -313,10 +303,10 @@
 
 TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) {
   // [[stage(fragment)]]
-  // fn main([[location(0)]] [[builtin(vertex_index)]] param : u32) {}
+  // fn main([[location(0)]] [[builtin(sample_index)]] param : u32) {}
   auto* param = Param("param", ty.u32(),
                       {Location(Source{{13, 43}}, 0),
-                       Builtin(Source{{14, 52}}, ast::Builtin::kVertexIndex)});
+                       Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)});
   Func(Source{{12, 34}}, "main", {param}, ty.void_(), {},
        {Stage(ast::PipelineStage::kFragment)});
 
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 7a4c7a7..7002b5a 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -928,20 +928,15 @@
 
   for (auto* deco : info->declaration->decorations()) {
     if (!func->IsEntryPoint() && !deco->Is<ast::InternalDecoration>()) {
-      AddError("decoration is not valid for function parameters",
-               deco->source());
+      AddError(
+          "decoration is not valid for non-entry point function parameters",
+          deco->source());
       return false;
-    }
-
-    if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
-      if (!ValidateBuiltinDecoration(builtin, info->type)) {
-        return false;
-      }
     } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
       if (!ValidateInterpolateDecoration(interpolate, info->type)) {
         return false;
       }
-    } else if (!deco->IsAnyOf<ast::LocationDecoration,
+    } else if (!deco->IsAnyOf<ast::LocationDecoration, ast::BuiltinDecoration,
                               ast::InternalDecoration>() &&
                (IsValidationEnabled(
                     info->declaration->decorations(),
@@ -989,10 +984,25 @@
 }
 
 bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
-                                         const sem::Type* storage_type) {
+                                         const sem::Type* storage_type,
+                                         const bool is_input) {
   auto* type = storage_type->UnwrapRef();
+  const auto stage = current_function_
+                         ? current_function_->declaration->pipeline_stage()
+                         : ast::PipelineStage::kNone;
+  std::stringstream stage_name;
+  stage_name << stage;
+  bool is_stage_mismatch = false;
   switch (deco->value()) {
     case ast::Builtin::kPosition:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kFragment && is_input) &&
+          !(stage == ast::PipelineStage::kVertex && !is_input)) {
+        AddError(deco_to_str(deco) + " cannot be used in " +
+                     (is_input ? "input of " : "output of ") +
+                     stage_name.str() + " pipeline stage",
+                 deco->source());
+      }
       if (!(type->is_float_vector() && type->As<sem::Vector>()->size() == 4)) {
         AddError("store type of " + deco_to_str(deco) + " must be 'vec4<f32>'",
                  deco->source());
@@ -1002,6 +1012,10 @@
     case ast::Builtin::kGlobalInvocationId:
     case ast::Builtin::kLocalInvocationId:
     case ast::Builtin::kWorkgroupId:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kCompute && is_input)) {
+        is_stage_mismatch = true;
+      }
       if (!(type->is_unsigned_integer_vector() &&
             type->As<sem::Vector>()->size() == 3)) {
         AddError("store type of " + deco_to_str(deco) + " must be 'vec3<u32>'",
@@ -1010,6 +1024,10 @@
       }
       break;
     case ast::Builtin::kFragDepth:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kFragment && !is_input)) {
+        is_stage_mismatch = true;
+      }
       if (!type->Is<sem::F32>()) {
         AddError("store type of " + deco_to_str(deco) + " must be 'f32'",
                  deco->source());
@@ -1017,6 +1035,10 @@
       }
       break;
     case ast::Builtin::kFrontFacing:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kFragment && is_input)) {
+        is_stage_mismatch = true;
+      }
       if (!type->Is<sem::Bool>()) {
         AddError("store type of " + deco_to_str(deco) + " must be 'bool'",
                  deco->source());
@@ -1024,10 +1046,44 @@
       }
       break;
     case ast::Builtin::kLocalInvocationIndex:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kCompute && is_input)) {
+        is_stage_mismatch = true;
+      }
+      if (!type->Is<sem::U32>()) {
+        AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
+                 deco->source());
+        return false;
+      }
+      break;
     case ast::Builtin::kVertexIndex:
     case ast::Builtin::kInstanceIndex:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kVertex && is_input)) {
+        is_stage_mismatch = true;
+      }
+      if (!type->Is<sem::U32>()) {
+        AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
+                 deco->source());
+        return false;
+      }
+      break;
     case ast::Builtin::kSampleMask:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kFragment)) {
+        is_stage_mismatch = true;
+      }
+      if (!type->Is<sem::U32>()) {
+        AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
+                 deco->source());
+        return false;
+      }
+      break;
     case ast::Builtin::kSampleIndex:
+      if (stage != ast::PipelineStage::kNone &&
+          !(stage == ast::PipelineStage::kFragment && is_input)) {
+        is_stage_mismatch = true;
+      }
       if (!type->Is<sem::U32>()) {
         AddError("store type of " + deco_to_str(deco) + " must be 'u32'",
                  deco->source());
@@ -1037,6 +1093,15 @@
     default:
       break;
   }
+
+  if (is_stage_mismatch) {
+    AddError(deco_to_str(deco) + " cannot be used in " +
+                 (is_input ? "input of " : "output of ") + stage_name.str() +
+                 " pipeline stage",
+             deco->source());
+    return false;
+  }
+
   return true;
 }
 
@@ -1070,12 +1135,9 @@
     return false;
   }
 
-  auto stage_deco_count = 0;
   auto workgroup_deco_count = 0;
   for (auto* deco : func->decorations()) {
-    if (deco->Is<ast::StageDecoration>()) {
-      stage_deco_count++;
-    } else if (deco->Is<ast::WorkgroupDecoration>()) {
+    if (deco->Is<ast::WorkgroupDecoration>()) {
       workgroup_deco_count++;
       if (func->pipeline_stage() != ast::PipelineStage::kCompute) {
         AddError(
@@ -1083,7 +1145,8 @@
             deco->source());
         return false;
       }
-    } else if (!deco->Is<ast::InternalDecoration>()) {
+    } else if (!deco->IsAnyOf<ast::StageDecoration,
+                              ast::InternalDecoration>()) {
       AddError("decoration is not valid for functions", deco->source());
       return false;
     }
@@ -1119,20 +1182,24 @@
 
     for (auto* deco : func->return_type_decorations()) {
       if (!func->IsEntryPoint()) {
-        AddError("decoration is not valid for function return types",
-                 deco->source());
+        AddError(
+            "decoration is not valid for non-entry point function return types",
+            deco->source());
         return false;
       }
 
-      if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
-        if (!ValidateBuiltinDecoration(builtin, info->return_type)) {
-          return false;
-        }
-      } else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
+      if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
         if (!ValidateInterpolateDecoration(interpolate, info->return_type)) {
           return false;
         }
-      } else if (!deco->Is<ast::LocationDecoration>()) {
+      } else if (!deco->IsAnyOf<ast::LocationDecoration, ast::BuiltinDecoration,
+                                ast::InternalDecoration>() &&
+                 (IsValidationEnabled(
+                      info->declaration->decorations(),
+                      ast::DisabledValidation::kEntryPointParameter) &&
+                  IsValidationEnabled(info->declaration->decorations(),
+                                      ast::DisabledValidation::
+                                          kIgnoreAtomicFunctionParameter))) {
         AddError("decoration is not valid for entry point return types",
                  deco->source());
         return false;
@@ -1192,6 +1259,12 @@
             }
             builtins.emplace(builtin->value());
 
+            if (!ValidateBuiltinDecoration(builtin, ty,
+                                           /* is_input */ param_or_ret ==
+                                               ParamOrRetType::kParameter)) {
+              return false;
+            }
+
           } else if (auto* location = deco->As<ast::LocationDecoration>()) {
             if (pipeline_io_attribute) {
               AddError("multiple entry point IO attributes", deco->source());
@@ -1409,7 +1482,6 @@
       return false;
     }
 
-    // TODO(amaiorano): Validate parameter decorations
     for (auto* deco : param->decorations()) {
       Mark(deco);
     }
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index e8e553f..3306294 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -273,7 +273,8 @@
   bool ValidateAtomicUses();
   bool ValidateAssignment(const ast::AssignmentStatement* a);
   bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
-                                 const sem::Type* storage_type);
+                                 const sem::Type* storage_type,
+                                 const bool is_input = true);
   bool ValidateCallStatement(ast::CallStatement* stmt);
   bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
   bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc
index fee3f57..32b01d2 100644
--- a/src/resolver/validation_test.cc
+++ b/src/resolver/validation_test.cc
@@ -85,9 +85,8 @@
   // fn f2(){ dst = wg; }
   // fn f1() { f2(); }
   // [[stage(fragment)]]
-  // fn f0() -> [[builtin(position)]] vec4<f32> {
+  // fn f0() {
   //  f1();
-  //  return dst;
   //}
 
   Global(Source{{1, 2}}, "wg", ty.vec4<f32>(), ast::StorageClass::kWorkgroup);
@@ -97,10 +96,9 @@
   Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt});
   Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(),
        {Ignore(Call("f2"))});
-  Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4<f32>(),
-       {Ignore(Call("f1")), Return(Expr("dst"))},
-       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
-       ast::DecorationList{Builtin(ast::Builtin::kPosition)});
+  Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.void_(),
+       {Ignore(Call("f1"))},
+       ast::DecorationList{Stage(ast::PipelineStage::kFragment)});
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index ebadbb4..01cc199 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -296,7 +296,7 @@
 TEST_F(BuilderTest, SampleIndex_SampleRateShadingCapability) {
   Func("main",
        {Param("sample_index", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})},
-       ty.void_(), {}, {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
+       ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)});
 
   spirv::Builder& b = SanitizeAndBuild();