[wgsl-reader] Allow decorations on function return types
Add a return type decoration list field to ast::Function.
Bug: tint:513
Change-Id: I41c1087f21a87731eb48ec7642997da5ae7f2baa
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44601
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 069b66c..b2d3174 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -28,13 +28,15 @@
VariableList params,
type::Type* return_type,
BlockStatement* body,
- DecorationList decorations)
+ DecorationList decorations,
+ DecorationList return_type_decorations)
: Base(source),
symbol_(symbol),
params_(std::move(params)),
return_type_(return_type),
body_(body),
- decorations_(std::move(decorations)) {
+ decorations_(std::move(decorations)),
+ return_type_decorations_(std::move(return_type_decorations)) {
for (auto* param : params_) {
TINT_ASSERT(param);
}
@@ -77,7 +79,8 @@
auto* ret = ctx->Clone(return_type_);
auto* b = ctx->Clone(body_);
auto decos = ctx->Clone(decorations_);
- return ctx->dst->create<Function>(src, sym, p, ret, b, decos);
+ auto ret_decos = ctx->Clone(return_type_decorations_);
+ return ctx->dst->create<Function>(src, sym, p, ret, b, decos, ret_decos);
}
void Function::to_str(const semantic::Info& sem,
diff --git a/src/ast/function.h b/src/ast/function.h
index c3db11f..eb856c3 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -42,12 +42,14 @@
/// @param return_type the return type
/// @param body the function body
/// @param decorations the function decorations
+ /// @param return_type_decorations the return type decorations
Function(const Source& source,
Symbol symbol,
VariableList params,
type::Type* return_type,
BlockStatement* body,
- DecorationList decorations);
+ DecorationList decorations,
+ DecorationList return_type_decorations);
/// Move constructor
Function(Function&&);
@@ -74,6 +76,11 @@
/// @returns the function return type.
type::Type* return_type() const { return return_type_; }
+ /// @returns the decorations attached to the function return type.
+ const DecorationList& return_type_decorations() const {
+ return return_type_decorations_;
+ }
+
/// @returns a pointer to the last statement of the function or nullptr if
// function is empty
const Statement* get_last_statement() const;
@@ -108,6 +115,7 @@
type::Type* const return_type_;
BlockStatement* const body_;
DecorationList const decorations_;
+ DecorationList const return_type_decorations_;
};
/// A list of functions
diff --git a/src/program_builder.h b/src/program_builder.h
index 39832a2..a8779a8 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -994,16 +994,18 @@
/// @param type the function return type
/// @param body the function body
/// @param decorations the function decorations
+ /// @param return_type_decorations the function return type decorations
/// @returns the function pointer
ast::Function* Func(Source source,
std::string name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
- ast::DecorationList decorations) {
- auto* func =
- create<ast::Function>(source, Symbols().Register(name), params, type,
- create<ast::BlockStatement>(body), decorations);
+ ast::DecorationList decorations,
+ ast::DecorationList return_type_decorations = {}) {
+ auto* func = create<ast::Function>(source, Symbols().Register(name), params,
+ type, create<ast::BlockStatement>(body),
+ decorations, return_type_decorations);
AST().AddFunction(func);
return func;
}
@@ -1014,15 +1016,17 @@
/// @param type the function return type
/// @param body the function body
/// @param decorations the function decorations
+ /// @param return_type_decorations the function return type decorations
/// @returns the function pointer
ast::Function* Func(std::string name,
ast::VariableList params,
type::Type* type,
ast::StatementList body,
- ast::DecorationList decorations) {
- auto* func =
- create<ast::Function>(Symbols().Register(name), params, type,
- create<ast::BlockStatement>(body), decorations);
+ ast::DecorationList decorations,
+ ast::DecorationList return_type_decorations = {}) {
+ auto* func = create<ast::Function>(Symbols().Register(name), params, type,
+ create<ast::BlockStatement>(body),
+ decorations, return_type_decorations);
AST().AddFunction(func);
return func;
}
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 350ef4b..63a124e 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -843,10 +843,10 @@
auto& statements = statements_stack_[0].GetStatements();
auto* body = create<ast::BlockStatement>(Source{}, statements);
- builder_.AST().AddFunction(
- create<ast::Function>(decl.source, builder_.Symbols().Register(decl.name),
- std::move(decl.params), decl.return_type, body,
- std::move(decl.decorations)));
+ builder_.AST().AddFunction(create<ast::Function>(
+ decl.source, builder_.Symbols().Register(decl.name),
+ std::move(decl.params), decl.return_type, body,
+ std::move(decl.decorations), ast::DecorationList{}));
// Maintain the invariant by repopulating the one and only element.
statements_stack_.clear();
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 34fc690..6010c47 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -167,8 +167,13 @@
ParserImpl::FunctionHeader::FunctionHeader(Source src,
std::string n,
ast::VariableList p,
- type::Type* ret_ty)
- : source(src), name(n), params(p), return_type(ret_ty) {}
+ type::Type* ret_ty,
+ ast::DecorationList ret_decos)
+ : source(src),
+ name(n),
+ params(p),
+ return_type(ret_ty),
+ return_type_decorations(ret_decos) {}
ParserImpl::FunctionHeader::~FunctionHeader() = default;
@@ -1185,7 +1190,7 @@
return create<ast::Function>(
header->source, builder_.Symbols().Register(header->name), header->params,
- header->return_type, body.value, decos);
+ header->return_type, body.value, decos, header->return_type_decorations);
}
// function_type_decl
@@ -1225,6 +1230,11 @@
if (!expect(use, Token::Type::kArrow))
return Failure::kErrored;
+ auto decos = decoration_list();
+ if (decos.errored) {
+ return Failure::kErrored;
+ }
+
auto type = function_type_decl();
if (type.errored) {
errored = true;
@@ -1235,8 +1245,8 @@
if (errored)
return Failure::kErrored;
- return FunctionHeader{source, name.value, std::move(params.value),
- type.value};
+ return FunctionHeader{source, name.value, std::move(params.value), type.value,
+ std::move(decos.value)};
}
// param_list
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index a2bf63d..080dda5 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -218,10 +218,12 @@
/// @param n function name
/// @param p function parameters
/// @param ret_ty function return type
+ /// @param ret_decos return type decorations
FunctionHeader(Source src,
std::string n,
ast::VariableList p,
- type::Type* ret_ty);
+ type::Type* ret_ty,
+ ast::DecorationList ret_decos);
/// Destructor
~FunctionHeader();
/// Assignment operator
@@ -237,6 +239,8 @@
ast::VariableList params;
/// Function return type
type::Type* return_type;
+ /// Function return type decorations
+ ast::DecorationList return_type_decorations;
};
/// VarDeclInfo contains the parsed information for variable declaration.
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index 86d0330..42c1a43 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -173,6 +173,37 @@
EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
}
+TEST_F(ParserImplTest, FunctionDecl_ReturnTypeDecorationList) {
+ auto p = parser("fn main() -> [[location(1)]] f32 { return 1.0; }");
+ auto decos = p->decoration_list();
+ EXPECT_FALSE(p->has_error()) << p->error();
+ ASSERT_FALSE(decos.errored);
+ EXPECT_FALSE(decos.matched);
+ auto f = p->function_decl(decos.value);
+ EXPECT_FALSE(p->has_error()) << p->error();
+ EXPECT_FALSE(f.errored);
+ EXPECT_TRUE(f.matched);
+ ASSERT_NE(f.value, nullptr);
+
+ EXPECT_EQ(f->symbol(), p->builder().Symbols().Get("main"));
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->Is<type::F32>());
+ ASSERT_EQ(f->params().size(), 0u);
+
+ auto& decorations = f->decorations();
+ EXPECT_EQ(decorations.size(), 0u);
+
+ auto& ret_type_decorations = f->return_type_decorations();
+ ASSERT_EQ(ret_type_decorations.size(), 1u);
+ auto* loc = ret_type_decorations[0]->As<ast::LocationDecoration>();
+ ASSERT_TRUE(loc != nullptr);
+ EXPECT_EQ(loc->value(), 1u);
+
+ auto* body = f->body();
+ ASSERT_EQ(body->size(), 1u);
+ EXPECT_TRUE(body->get(0)->Is<ast::ReturnStatement>());
+}
+
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
auto p = parser("fn main() -> { }");
auto decos = p->decoration_list();
diff --git a/src/reader/wgsl/parser_impl_function_header_test.cc b/src/reader/wgsl/parser_impl_function_header_test.cc
index cedceb0..7884613 100644
--- a/src/reader/wgsl/parser_impl_function_header_test.cc
+++ b/src/reader/wgsl/parser_impl_function_header_test.cc
@@ -33,6 +33,22 @@
EXPECT_TRUE(f->return_type->Is<type::Void>());
}
+TEST_F(ParserImplTest, FunctionHeader_DecoratedReturnType) {
+ auto p = parser("fn main() -> [[location(1)]] f32");
+ auto f = p->function_header();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ EXPECT_TRUE(f.matched);
+ EXPECT_FALSE(f.errored);
+
+ EXPECT_EQ(f->name, "main");
+ EXPECT_EQ(f->params.size(), 0u);
+ EXPECT_TRUE(f->return_type->Is<type::F32>());
+ ASSERT_TRUE(f->return_type_decorations.size() == 1u);
+ auto* loc = f->return_type_decorations[0]->As<ast::LocationDecoration>();
+ ASSERT_TRUE(loc != nullptr);
+ EXPECT_EQ(loc->value(), 1u);
+}
+
TEST_F(ParserImplTest, FunctionHeader_MissingIdent) {
auto p = parser("fn () -> void");
auto f = p->function_header();
diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc
index 6da5a2a..ee7f3b3 100644
--- a/src/transform/hlsl.cc
+++ b/src/transform/hlsl.cc
@@ -229,7 +229,8 @@
func->source(), ctx.Clone(func->symbol()), new_parameters,
ctx.Clone(func->return_type()),
ctx.dst->create<ast::BlockStatement>(new_body),
- ctx.Clone(func->decorations()));
+ ctx.Clone(func->decorations()),
+ ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func);
}
}
diff --git a/src/transform/msl.cc b/src/transform/msl.cc
index 28cae98..96f5f04 100644
--- a/src/transform/msl.cc
+++ b/src/transform/msl.cc
@@ -391,7 +391,8 @@
func->source(), ctx.Clone(func->symbol()), new_parameters,
ctx.Clone(func->return_type()),
ctx.dst->create<ast::BlockStatement>(new_body),
- ctx.Clone(func->decorations()));
+ ctx.Clone(func->decorations()),
+ ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func);
}
}
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index d990f55..a6f41fe 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -138,7 +138,8 @@
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
ctx.Clone(func->return_type()), ctx.Clone(func->body()),
- ctx.Clone(func->decorations()));
+ ctx.Clone(func->decorations()),
+ ctx.Clone(func->return_type_decorations()));
ctx.Replace(func, new_func);
}
}
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index e1e8210..fbf3e13 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -58,8 +58,9 @@
auto* body = ctx->dst->create<ast::BlockStatement>(
ctx->Clone(in->body()->source()), statements);
auto decos = ctx->Clone(in->decorations());
+ auto ret_decos = ctx->Clone(in->return_type_decorations());
return ctx->dst->create<ast::Function>(source, symbol, params, return_type,
- body, decos);
+ body, decos, ret_decos);
}
void Transform::RenameReservedKeywords(CloneContext* ctx,
diff --git a/src/validator/validator_decoration_test.cc b/src/validator/validator_decoration_test.cc
index 67362c4..b4dc099 100644
--- a/src/validator/validator_decoration_test.cc
+++ b/src/validator/validator_decoration_test.cc
@@ -148,6 +148,41 @@
false},
DecorationTestParams{DecorationKind::kWorkgroup, true}));
+using FunctionReturnTypeDecorationTest = ValidatorDecorationsTestWithParams;
+TEST_P(FunctionReturnTypeDecorationTest, Decoration_IsValid) {
+ auto params = GetParam();
+
+ Func("main", ast::VariableList{}, ty.f32(),
+ ast::StatementList{create<ast::ReturnStatement>(Expr(1.f))},
+ ast::DecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kVertex)},
+ ast::DecorationList{createDecoration(*this, params.kind)});
+
+ ValidatorImpl& v = Build();
+
+ if (params.should_pass) {
+ EXPECT_TRUE(v.Validate());
+ } else {
+ EXPECT_FALSE(v.Validate());
+ EXPECT_EQ(v.error(), "decoration is not valid for function return types");
+ }
+}
+INSTANTIATE_TEST_SUITE_P(
+ ValidatorTest,
+ FunctionReturnTypeDecorationTest,
+ testing::Values(DecorationTestParams{DecorationKind::kAccess, false},
+ DecorationTestParams{DecorationKind::kBinding, false},
+ DecorationTestParams{DecorationKind::kBuiltin, true},
+ DecorationTestParams{DecorationKind::kConstantId, false},
+ DecorationTestParams{DecorationKind::kGroup, false},
+ DecorationTestParams{DecorationKind::kLocation, true},
+ DecorationTestParams{DecorationKind::kStage, false},
+ DecorationTestParams{DecorationKind::kStride, false},
+ DecorationTestParams{DecorationKind::kStructBlock, false},
+ DecorationTestParams{DecorationKind::kStructMemberOffset,
+ false},
+ DecorationTestParams{DecorationKind::kWorkgroup, false}));
+
using StructDecorationTest = ValidatorDecorationsTestWithParams;
TEST_P(StructDecorationTest, Decoration_IsValid) {
auto params = GetParam();
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 56f14a7..1b18e68 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -248,6 +248,15 @@
"non-void function must end with a return statement");
return false;
}
+
+ for (auto* deco : current_function_->return_type_decorations()) {
+ if (!(deco->Is<ast::BuiltinDecoration>() ||
+ deco->Is<ast::LocationDecoration>())) {
+ add_error(deco->source(),
+ "decoration is not valid for function return types");
+ return false;
+ }
+ }
}
return true;
}