validator: Validate attributes on non-entry point functions

Also validate that workgroup_size is only applied to compute stages,
and not duplicated.

Fixed: tint:703
Change-Id: I02f4ddea305cad25ee0a99e13dc9e7fd1d5dc3ea
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51120
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index 85a26ea..aac4784 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -332,7 +332,6 @@
 
   ast::DecorationList decos =
       createDecorations(Source{{12, 34}}, *this, params.kind);
-  decos.emplace_back(Stage(ast::PipelineStage::kCompute));
   Func("foo", ast::VariableList{}, ty.void_(), ast::StatementList{}, decos);
 
   if (params.should_pass) {
@@ -355,10 +354,10 @@
                     TestParams{DecorationKind::kOverride, false},
                     TestParams{DecorationKind::kOffset, false},
                     TestParams{DecorationKind::kSize, false},
-                    // Skip kStage as we always apply it in this test
+                    // Skip kStage as we do not apply it in this test
                     TestParams{DecorationKind::kStride, false},
                     TestParams{DecorationKind::kStructBlock, false},
-                    TestParams{DecorationKind::kWorkgroup, true},
+                    // Skip kWorkgroup as this is a different error
                     TestParams{DecorationKind::kBindingAndGroup, false}));
 
 }  // namespace
@@ -658,5 +657,46 @@
 }  // namespace
 }  // namespace ResourceTests
 
+namespace WorkgroupDecorationTests {
+namespace {
+
+using WorkgroupDecoration = ResolverTest;
+
+TEST_F(WorkgroupDecoration, NotAnEntryPoint) {
+  Func("main", {}, ty.void_(), {},
+       {create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: the workgroup_size attribute is only valid for "
+            "compute stages");
+}
+
+TEST_F(WorkgroupDecoration, NotAComputeShader) {
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kFragment),
+        create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: the workgroup_size attribute is only valid for "
+            "compute stages");
+}
+
+TEST_F(WorkgroupDecoration, MultipleAttributes) {
+  Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        create<ast::WorkgroupDecoration>(1u),
+        create<ast::WorkgroupDecoration>(2u)});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: only one workgroup_size attribute permitted per "
+            "entry point");
+}
+
+}  // namespace
+}  // namespace WorkgroupDecorationTests
+
 }  // namespace resolver
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index d5166d9..145bdc0 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -823,6 +823,38 @@
     return false;
   }
 
+  auto stage_deco_count = 0;
+  auto workgroup_deco_count = 0;
+  for (auto* deco : func->decorations()) {
+    if (deco->Is<ast::StageDecoration>()) {
+      stage_deco_count++;
+    } else if (deco->Is<ast::WorkgroupDecoration>()) {
+      workgroup_deco_count++;
+      if (func->pipeline_stage() != ast::PipelineStage::kCompute) {
+        diagnostics_.add_error(
+            "the workgroup_size attribute is only valid for compute stages",
+            deco->source());
+        return false;
+      }
+    } else if (!deco->Is<ast::InternalDecoration>()) {
+      diagnostics_.add_error("decoration is not valid for functions",
+                             deco->source());
+      return false;
+    }
+  }
+  if (stage_deco_count > 1) {
+    diagnostics_.add_error(
+        "v-0020", "only one stage decoration permitted per entry point",
+        func->source());
+    return false;
+  }
+  if (workgroup_deco_count > 1) {
+    diagnostics_.add_error(
+        "only one workgroup_size attribute permitted per entry point",
+        func->source());
+    return false;
+  }
+
   for (auto* param : func->params()) {
     if (!ValidateParameter(variable_to_info_.at(param))) {
       return false;
@@ -867,23 +899,6 @@
 
 bool Resolver::ValidateEntryPoint(const ast::Function* func,
                                   const FunctionInfo* info) {
-  auto stage_deco_count = 0;
-  for (auto* deco : func->decorations()) {
-    if (deco->Is<ast::StageDecoration>()) {
-      stage_deco_count++;
-    } else if (!deco->Is<ast::WorkgroupDecoration>()) {
-      diagnostics_.add_error("decoration is not valid for functions",
-                             deco->source());
-      return false;
-    }
-  }
-  if (stage_deco_count > 1) {
-    diagnostics_.add_error(
-        "v-0020", "only one stage decoration permitted per entry point",
-        func->source());
-    return false;
-  }
-
   // Use a lambda to validate the entry point decorations for a type.
   // Persistent state is used to track which builtins and locations have already
   // been seen, in order to catch conflicts.
diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc
index 16b1c0d..fc69ccf 100644
--- a/src/writer/wgsl/generator_impl_function_test.cc
+++ b/src/writer/wgsl/generator_impl_function_test.cc
@@ -74,6 +74,7 @@
                         Return(),
                     },
                     ast::DecorationList{
+                        Stage(ast::PipelineStage::kCompute),
                         create<ast::WorkgroupDecoration>(2u, 4u, 6u),
                     });
 
@@ -82,7 +83,7 @@
   gen.increment_indent();
 
   ASSERT_TRUE(gen.EmitFunction(func));
-  EXPECT_EQ(gen.result(), R"(  [[workgroup_size(2, 4, 6)]]
+  EXPECT_EQ(gen.result(), R"(  [[stage(compute), workgroup_size(2, 4, 6)]]
   fn my_func() {
     discard;
     return;
@@ -113,30 +114,6 @@
 )");
 }
 
-TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) {
-  auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
-                    ast::StatementList{
-                        create<ast::DiscardStatement>(),
-                        Return(),
-                    },
-                    ast::DecorationList{
-                        Stage(ast::PipelineStage::kFragment),
-                        create<ast::WorkgroupDecoration>(2u, 4u, 6u),
-                    });
-
-  GeneratorImpl& gen = Build();
-
-  gen.increment_indent();
-
-  ASSERT_TRUE(gen.EmitFunction(func));
-  EXPECT_EQ(gen.result(), R"(  [[stage(fragment), workgroup_size(2, 4, 6)]]
-  fn my_func() {
-    discard;
-    return;
-  }
-)");
-}
-
 TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_Parameters) {
   auto vec4 = ty.vec4<f32>();
   auto* coord = Param("coord", vec4, {Builtin(ast::Builtin::kPosition)});