[wgsl-reader] Allow decorations on function return types

Add a return type decoration list field to ast::Function.

Bug: tint:513
Change-Id: I41c1087f21a87731eb48ec7642997da5ae7f2baa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44601
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 069b66c..b2d3174 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -28,13 +28,15 @@
                    VariableList params,
                    type::Type* return_type,
                    BlockStatement* body,
-                   DecorationList decorations)
+                   DecorationList decorations,
+                   DecorationList return_type_decorations)
     : Base(source),
       symbol_(symbol),
       params_(std::move(params)),
       return_type_(return_type),
       body_(body),
-      decorations_(std::move(decorations)) {
+      decorations_(std::move(decorations)),
+      return_type_decorations_(std::move(return_type_decorations)) {
   for (auto* param : params_) {
     TINT_ASSERT(param);
   }
@@ -77,7 +79,8 @@
   auto* ret = ctx->Clone(return_type_);
   auto* b = ctx->Clone(body_);
   auto decos = ctx->Clone(decorations_);
-  return ctx->dst->create<Function>(src, sym, p, ret, b, decos);
+  auto ret_decos = ctx->Clone(return_type_decorations_);
+  return ctx->dst->create<Function>(src, sym, p, ret, b, decos, ret_decos);
 }
 
 void Function::to_str(const semantic::Info& sem,
diff --git a/src/ast/function.h b/src/ast/function.h
index c3db11f..eb856c3 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -42,12 +42,14 @@
   /// @param return_type the return type
   /// @param body the function body
   /// @param decorations the function decorations
+  /// @param return_type_decorations the return type decorations
   Function(const Source& source,
            Symbol symbol,
            VariableList params,
            type::Type* return_type,
            BlockStatement* body,
-           DecorationList decorations);
+           DecorationList decorations,
+           DecorationList return_type_decorations);
   /// Move constructor
   Function(Function&&);
 
@@ -74,6 +76,11 @@
   /// @returns the function return type.
   type::Type* return_type() const { return return_type_; }
 
+  /// @returns the decorations attached to the function return type.
+  const DecorationList& return_type_decorations() const {
+    return return_type_decorations_;
+  }
+
   /// @returns a pointer to the last statement of the function or nullptr if
   // function is empty
   const Statement* get_last_statement() const;
@@ -108,6 +115,7 @@
   type::Type* const return_type_;
   BlockStatement* const body_;
   DecorationList const decorations_;
+  DecorationList const return_type_decorations_;
 };
 
 /// A list of functions
diff --git a/src/program_builder.h b/src/program_builder.h
index 39832a2..a8779a8 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -994,16 +994,18 @@
   /// @param type the function return type
   /// @param body the function body
   /// @param decorations the function decorations
+  /// @param return_type_decorations the function return type decorations
   /// @returns the function pointer
   ast::Function* Func(Source source,
                       std::string name,
                       ast::VariableList params,
                       type::Type* type,
                       ast::StatementList body,
-                      ast::DecorationList decorations) {
-    auto* func =
-        create<ast::Function>(source, Symbols().Register(name), params, type,
-                              create<ast::BlockStatement>(body), decorations);
+                      ast::DecorationList decorations,
+                      ast::DecorationList return_type_decorations = {}) {
+    auto* func = create<ast::Function>(source, Symbols().Register(name), params,
+                                       type, create<ast::BlockStatement>(body),
+                                       decorations, return_type_decorations);
     AST().AddFunction(func);
     return func;
   }
@@ -1014,15 +1016,17 @@
   /// @param type the function return type
   /// @param body the function body
   /// @param decorations the function decorations
+  /// @param return_type_decorations the function return type decorations
   /// @returns the function pointer
   ast::Function* Func(std::string name,
                       ast::VariableList params,
                       type::Type* type,
                       ast::StatementList body,
-                      ast::DecorationList decorations) {
-    auto* func =
-        create<ast::Function>(Symbols().Register(name), params, type,
-                              create<ast::BlockStatement>(body), decorations);
+                      ast::DecorationList decorations,
+                      ast::DecorationList return_type_decorations = {}) {
+    auto* func = create<ast::Function>(Symbols().Register(name), params, type,
+                                       create<ast::BlockStatement>(body),
+                                       decorations, return_type_decorations);
     AST().AddFunction(func);
     return func;
   }
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 350ef4b..63a124e 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -843,10 +843,10 @@
 
   auto& statements = statements_stack_[0].GetStatements();
   auto* body = create<ast::BlockStatement>(Source{}, statements);
-  builder_.AST().AddFunction(
-      create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
-                            std::move(decl.params), decl.return_type, body,
-                            std::move(decl.decorations)));
+  builder_.AST().AddFunction(create<ast::Function>(
+      decl.source, builder_.Symbols().Register(decl.name),
+      std::move(decl.params), decl.return_type, body,
+      std::move(decl.decorations), ast::DecorationList{}));
 
   // Maintain the invariant by repopulating the one and only element.
   statements_stack_.clear();
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 34fc690..6010c47 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -167,8 +167,13 @@
 ParserImpl::FunctionHeader::FunctionHeader(Source src,
                                            std::string n,
                                            ast::VariableList p,
-                                           type::Type* ret_ty)
-    : source(src), name(n), params(p), return_type(ret_ty) {}
+                                           type::Type* ret_ty,
+                                           ast::DecorationList ret_decos)
+    : source(src),
+      name(n),
+      params(p),
+      return_type(ret_ty),
+      return_type_decorations(ret_decos) {}
 
 ParserImpl::FunctionHeader::~FunctionHeader() = default;
 
@@ -1185,7 +1190,7 @@
 
   return create<ast::Function>(
       header->source, builder_.Symbols().Register(header->name), header->params,
-      header->return_type, body.value, decos);
+      header->return_type, body.value, decos, header->return_type_decorations);
 }
 
 // function_type_decl
@@ -1225,6 +1230,11 @@
   if (!expect(use, Token::Type::kArrow))
     return Failure::kErrored;
 
+  auto decos = decoration_list();
+  if (decos.errored) {
+    return Failure::kErrored;
+  }
+
   auto type = function_type_decl();
   if (type.errored) {
     errored = true;
@@ -1235,8 +1245,8 @@
   if (errored)
     return Failure::kErrored;
 
-  return FunctionHeader{source, name.value, std::move(params.value),
-                        type.value};
+  return FunctionHeader{source, name.value, std::move(params.value), type.value,
+                        std::move(decos.value)};
 }
 
 // param_list
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index a2bf63d..080dda5 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -218,10 +218,12 @@
     /// @param n function name
     /// @param p function parameters
     /// @param ret_ty function return type
+    /// @param ret_decos return type decorations
     FunctionHeader(Source src,
                    std::string n,
                    ast::VariableList p,
-                   type::Type* ret_ty);
+                   type::Type* ret_ty,
+                   ast::DecorationList ret_decos);
     /// Destructor
     ~FunctionHeader();
     /// Assignment operator
@@ -237,6 +239,8 @@
     ast::VariableList params;
     /// Function return type
     type::Type* return_type;
+    /// Function return type decorations
+    ast::DecorationList return_type_decorations;
   };
 
   /// VarDeclInfo contains the parsed information for variable declaration.
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index 86d0330..42c1a43 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -173,6 +173,37 @@
   EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
 }
 
+TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
+  auto p = parser("fn main() -> [[location(1)]] f32 { return 1.0; }");
+  auto decos = p->decoration_list();
+  EXPECT_FALSE(p->has_error()) << p->error();
+  ASSERT_FALSE(decos.errored);
+  EXPECT_FALSE(decos.matched);
+  auto f = p->function_decl(decos.value);
+  EXPECT_FALSE(p->has_error()) << p->error();
+  EXPECT_FALSE(f.errored);
+  EXPECT_TRUE(f.matched);
+  ASSERT_NE(f.value, nullptr);
+
+  EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
+  ASSERT_NE(f->return_type(), nullptr);
+  EXPECT_TRUE(f->return_type()->Is<type::F32>());
+  ASSERT_EQ(f->params().size(), 0u);
+
+  auto& decorations = f->decorations();
+  EXPECT_EQ(decorations.size(), 0u);
+
+  auto& ret_type_decorations = f->return_type_decorations();
+  ASSERT_EQ(ret_type_decorations.size(), 1u);
+  auto* loc = ret_type_decorations[0]->As<ast::LocationDecoration>();
+  ASSERT_TRUE(loc != nullptr);
+  EXPECT_EQ(loc->value(), 1u);
+
+  auto* body = f->body();
+  ASSERT_EQ(body->size(), 1u);
+  EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
+}
+
 TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
   auto p = parser("fn main() -> { }");
   auto decos = p->decoration_list();
diff --git a/src/reader/wgsl/parser_impl_function_header_test.cc b/src/reader/wgsl/parser_impl_function_header_test.cc
index cedceb0..7884613 100644
--- a/src/reader/wgsl/parser_impl_function_header_test.cc
+++ b/src/reader/wgsl/parser_impl_function_header_test.cc
@@ -33,6 +33,22 @@
   EXPECT_TRUE(f->return_type->Is<type::Void>());
 }
 
+TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) {
+  auto p = parser("fn main() -> [[location(1)]] f32");
+  auto f = p->function_header();
+  ASSERT_FALSE(p->has_error()) << p->error();
+  EXPECT_TRUE(f.matched);
+  EXPECT_FALSE(f.errored);
+
+  EXPECT_EQ(f->name, "main");
+  EXPECT_EQ(f->params.size(), 0u);
+  EXPECT_TRUE(f->return_type->Is<type::F32>());
+  ASSERT_TRUE(f->return_type_decorations.size() == 1u);
+  auto* loc = f->return_type_decorations[0]->As<ast::LocationDecoration>();
+  ASSERT_TRUE(loc != nullptr);
+  EXPECT_EQ(loc->value(), 1u);
+}
+
 TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
   auto p = parser("fn () -> void");
   auto f = p->function_header();
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 6da5a2a..ee7f3b3 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -229,7 +229,8 @@
         func->source(), ctx.Clone(func->symbol()), new_parameters,
         ctx.Clone(func->return_type()),
         ctx.dst->create<ast::BlockStatement>(new_body),
-        ctx.Clone(func->decorations()));
+        ctx.Clone(func->decorations()),
+        ctx.Clone(func->return_type_decorations()));
     ctx.Replace(func, new_func);
   }
 }
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index 28cae98..96f5f04 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -391,7 +391,8 @@
         func->source(), ctx.Clone(func->symbol()), new_parameters,
         ctx.Clone(func->return_type()),
         ctx.dst->create<ast::BlockStatement>(new_body),
-        ctx.Clone(func->decorations()));
+        ctx.Clone(func->decorations()),
+        ctx.Clone(func->return_type_decorations()));
     ctx.Replace(func, new_func);
   }
 }
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index d990f55..a6f41fe 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -138,7 +138,8 @@
     auto* new_func = ctx.dst->create<ast::Function>(
         func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
         ctx.Clone(func->return_type()), ctx.Clone(func->body()),
-        ctx.Clone(func->decorations()));
+        ctx.Clone(func->decorations()),
+        ctx.Clone(func->return_type_decorations()));
     ctx.Replace(func, new_func);
   }
 }
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index e1e8210..fbf3e13 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -58,8 +58,9 @@
   auto* body = ctx->dst->create<ast::BlockStatement>(
       ctx->Clone(in->body()->source()), statements);
   auto decos = ctx->Clone(in->decorations());
+  auto ret_decos = ctx->Clone(in->return_type_decorations());
   return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
-                                         body, decos);
+                                         body, decos, ret_decos);
 }
 
 void Transform::RenameReservedKeywords(CloneContext* ctx,
diff --git a/src/validator/validator_decoration_test.cc b/src/validator/validator_decoration_test.cc
index 67362c4..b4dc099 100644
--- a/src/validator/validator_decoration_test.cc
+++ b/src/validator/validator_decoration_test.cc
@@ -148,6 +148,41 @@
                                          false},
                     DecorationTestParams{DecorationKind::kWorkgroup, true}));
 
+using FunctionReturnTypeDecorationTest = ValidatorDecorationsTestWithParams;
+TEST_P(FunctionReturnTypeDecorationTest, Decoration_IsValid) {
+  auto params = GetParam();
+
+  Func("main", ast::VariableList{}, ty.f32(),
+       ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex)},
+       ast::DecorationList{createDecoration(*this, params.kind)});
+
+  ValidatorImpl& v = Build();
+
+  if (params.should_pass) {
+    EXPECT_TRUE(v.Validate());
+  } else {
+    EXPECT_FALSE(v.Validate());
+    EXPECT_EQ(v.error(), "decoration is not valid for function return types");
+  }
+}
+INSTANTIATE_TEST_SUITE_P(
+    ValidatorTest,
+    FunctionReturnTypeDecorationTest,
+    testing::Values(DecorationTestParams{DecorationKind::kAccess, false},
+                    DecorationTestParams{DecorationKind::kBinding, false},
+                    DecorationTestParams{DecorationKind::kBuiltin, true},
+                    DecorationTestParams{DecorationKind::kConstantId, false},
+                    DecorationTestParams{DecorationKind::kGroup, false},
+                    DecorationTestParams{DecorationKind::kLocation, true},
+                    DecorationTestParams{DecorationKind::kStage, false},
+                    DecorationTestParams{DecorationKind::kStride, false},
+                    DecorationTestParams{DecorationKind::kStructBlock, false},
+                    DecorationTestParams{DecorationKind::kStructMemberOffset,
+                                         false},
+                    DecorationTestParams{DecorationKind::kWorkgroup, false}));
+
 using StructDecorationTest = ValidatorDecorationsTestWithParams;
 TEST_P(StructDecorationTest, Decoration_IsValid) {
   auto params = GetParam();
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 56f14a7..1b18e68 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -248,6 +248,15 @@
                 "non-void function must end with a return statement");
       return false;
     }
+
+    for (auto* deco : current_function_->return_type_decorations()) {
+      if (!(deco->Is<ast::BuiltinDecoration>() ||
+            deco->Is<ast::LocationDecoration>())) {
+        add_error(deco->source(),
+                  "decoration is not valid for function return types");
+        return false;
+      }
+    }
   }
   return true;
 }