[wgsl-reader] Add workgroup_size parsing
This CL adds parsing of the `workgroup_size` function decoration.
Change-Id: Ia90efc2c014ac0e1614429280cc903d30cf8171d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28663
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index a37e2a0..f54e12a 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -900,6 +900,8 @@
"src/reader/wgsl/parser_impl_exclusive_or_expression_test.cc",
"src/reader/wgsl/parser_impl_for_stmt_test.cc",
"src/reader/wgsl/parser_impl_function_decl_test.cc",
+ "src/reader/wgsl/parser_impl_function_decoration_list_test.cc",
+ "src/reader/wgsl/parser_impl_function_decoration_test.cc",
"src/reader/wgsl/parser_impl_function_header_test.cc",
"src/reader/wgsl/parser_impl_function_type_decl_test.cc",
"src/reader/wgsl/parser_impl_global_constant_decl_test.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 04fbea4..c87bd64 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -431,6 +431,8 @@
reader/wgsl/parser_impl_exclusive_or_expression_test.cc
reader/wgsl/parser_impl_for_stmt_test.cc
reader/wgsl/parser_impl_function_decl_test.cc
+ reader/wgsl/parser_impl_function_decoration_list_test.cc
+ reader/wgsl/parser_impl_function_decoration_test.cc
reader/wgsl/parser_impl_function_header_test.cc
reader/wgsl/parser_impl_function_type_decl_test.cc
reader/wgsl/parser_impl_global_constant_decl_test.cc
diff --git a/src/ast/builtin.cc b/src/ast/builtin.cc
index b0844cf..d5818a1 100644
--- a/src/ast/builtin.cc
+++ b/src/ast/builtin.cc
@@ -47,10 +47,6 @@
out << "frag_depth";
break;
}
- case Builtin::kWorkgroupSize: {
- out << "workgroup_size";
- break;
- }
case Builtin::kLocalInvocationId: {
out << "local_invocation_id";
break;
diff --git a/src/ast/builtin.h b/src/ast/builtin.h
index 7fb3fe5..5d0d763 100644
--- a/src/ast/builtin.h
+++ b/src/ast/builtin.h
@@ -29,7 +29,6 @@
kFrontFacing,
kFragCoord,
kFragDepth,
- kWorkgroupSize,
kLocalInvocationId,
kLocalInvocationIdx,
kGlobalInvocationId
diff --git a/src/ast/function.h b/src/ast/function.h
index d64c023..078f692 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -82,6 +82,12 @@
/// @returns the function params
const VariableList& params() const { return params_; }
+ /// Sets the function decorations
+ /// @param decos the decorations to set. This will overwrite any existing
+ /// decorations
+ void set_decorations(ast::FunctionDecorationList decos) {
+ decorations_ = std::move(decos);
+ }
/// Adds a decoration to the function
/// @param deco the decoration to set
void add_decoration(std::unique_ptr<FunctionDecoration> deco) {
diff --git a/src/reader/spirv/enum_converter.cc b/src/reader/spirv/enum_converter.cc
index 2b54362..1efc523 100644
--- a/src/reader/spirv/enum_converter.cc
+++ b/src/reader/spirv/enum_converter.cc
@@ -80,8 +80,6 @@
return ast::Builtin::kFragCoord;
case SpvBuiltInFragDepth:
return ast::Builtin::kFragDepth;
- case SpvBuiltInWorkgroupSize:
- return ast::Builtin::kWorkgroupSize;
case SpvBuiltInLocalInvocationId:
return ast::Builtin::kLocalInvocationId;
case SpvBuiltInLocalInvocationIndex:
diff --git a/src/reader/spirv/enum_converter_test.cc b/src/reader/spirv/enum_converter_test.cc
index fb9c8d0..6d54225 100644
--- a/src/reader/spirv/enum_converter_test.cc
+++ b/src/reader/spirv/enum_converter_test.cc
@@ -215,8 +215,6 @@
BuiltinCase{SpvBuiltInFrontFacing, true, ast::Builtin::kFrontFacing},
BuiltinCase{SpvBuiltInFragCoord, true, ast::Builtin::kFragCoord},
BuiltinCase{SpvBuiltInFragDepth, true, ast::Builtin::kFragDepth},
- BuiltinCase{SpvBuiltInWorkgroupSize, true,
- ast::Builtin::kWorkgroupSize},
BuiltinCase{SpvBuiltInLocalInvocationId, true,
ast::Builtin::kLocalInvocationId},
BuiltinCase{SpvBuiltInLocalInvocationIndex, true,
diff --git a/src/reader/wgsl/lexer.cc b/src/reader/wgsl/lexer.cc
index 37bfeb5..4d2a184 100644
--- a/src/reader/wgsl/lexer.cc
+++ b/src/reader/wgsl/lexer.cc
@@ -729,7 +729,8 @@
return {Token::Type::kVoid, source, "void"};
if (str == "workgroup")
return {Token::Type::kWorkgroup, source, "workgroup"};
-
+ if (str == "workgroup_size")
+ return {Token::Type::kWorkgroupSize, source, "workgroup_size"};
return {};
}
diff --git a/src/reader/wgsl/lexer_test.cc b/src/reader/wgsl/lexer_test.cc
index d21f3b6..b722246 100644
--- a/src/reader/wgsl/lexer_test.cc
+++ b/src/reader/wgsl/lexer_test.cc
@@ -542,7 +542,8 @@
TokenData{"vec4", Token::Type::kVec4},
TokenData{"vertex", Token::Type::kVertex},
TokenData{"void", Token::Type::kVoid},
- TokenData{"workgroup", Token::Type::kWorkgroup}));
+ TokenData{"workgroup", Token::Type::kWorkgroup},
+ TokenData{"workgroup_size", Token::Type::kWorkgroupSize}));
using KeywordTest_Reserved = testing::TestWithParam<const char*>;
TEST_P(KeywordTest_Reserved, Parses) {
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index ae83f80..c3f79f3 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -63,6 +63,7 @@
#include "src/ast/unary_op.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
+#include "src/ast/workgroup_decoration.h"
#include "src/reader/wgsl/lexer.h"
#include "src/type_manager.h"
@@ -95,9 +96,6 @@
if (str == "frag_depth") {
return ast::Builtin::kFragDepth;
}
- if (str == "workgroup_size") {
- return ast::Builtin::kWorkgroupSize;
- }
if (str == "local_invocation_id") {
return ast::Builtin::kLocalInvocationId;
}
@@ -110,6 +108,14 @@
return ast::Builtin::kNone;
}
+bool IsVariableDecoration(Token t) {
+ return t.IsLocation() || t.IsBuiltin() || t.IsBinding() || t.IsSet();
+}
+
+bool IsFunctionDecoration(Token t) {
+ return t.IsWorkgroupSize();
+}
+
} // namespace
ParserImpl::ParserImpl(Context* ctx, const std::string& input)
@@ -444,18 +450,25 @@
if (!t.IsAttrLeft())
return decos;
+ // Check the empty list before verifying the contents
+ t = peek(1);
+ if (t.IsAttrRight()) {
+ set_error(t, "empty variable decoration list");
+ return {};
+ }
+
+ // Make sure we're looking at variable decorations not some other kind
+ if (!IsVariableDecoration(peek(1))) {
+ return decos;
+ }
+
next(); // consume the peek
auto deco = variable_decoration();
if (has_error())
return {};
if (deco == nullptr) {
- t = peek();
- if (t.IsAttrRight()) {
- set_error(t, "empty variable decoration list");
- return {};
- }
- set_error(t, "missing variable decoration for decoration list");
+ set_error(peek(), "missing variable decoration for decoration list");
return {};
}
for (;;) {
@@ -1738,13 +1751,29 @@
}
// function_decl
-// : function_header body_stmt
+// : function_decoration_decl* function_header body_stmt
std::unique_ptr<ast::Function> ParserImpl::function_decl() {
+ ast::FunctionDecorationList decos;
+ for (;;) {
+ size_t s = decos.size();
+ if (!function_decoration_decl(decos)) {
+ return nullptr;
+ }
+ if (decos.size() == s) {
+ break;
+ }
+ }
+
auto f = function_header();
if (has_error())
return nullptr;
- if (f == nullptr)
+ if (f == nullptr) {
+ if (decos.size() > 0) {
+ set_error(peek(), "error parsing function declaration");
+ }
return nullptr;
+ }
+ f->set_decorations(std::move(decos));
auto body = body_stmt();
if (has_error())
@@ -1754,6 +1783,131 @@
return f;
}
+// function_decoration_decl
+// : ATTR_LEFT (function_decoration COMMA)* function_decoration ATTR_RIGHT
+bool ParserImpl::function_decoration_decl(ast::FunctionDecorationList& decos) {
+ auto t = peek();
+ if (!t.IsAttrLeft()) {
+ return true;
+ }
+ // Handle error on empty attributes before the type check
+ t = peek(1);
+ if (t.IsAttrRight()) {
+ set_error(t, "missing decorations for function decoration block");
+ return false;
+ }
+
+ // Make sure we're looking at function decorations and not some other kind
+ if (!IsFunctionDecoration(peek(1))) {
+ return true;
+ }
+
+ next(); // Consume the peek
+
+ size_t count = 0;
+ for (;;) {
+ auto deco = function_decoration();
+ if (has_error()) {
+ return false;
+ }
+ if (deco == nullptr) {
+ set_error(peek(), "expected decoration but none found");
+ return false;
+ }
+ decos.push_back(std::move(deco));
+ count++;
+
+ t = peek();
+ if (!t.IsComma()) {
+ break;
+ }
+ next(); // Consume the peek
+ }
+ if (count == 0) {
+ set_error(peek(), "missing decorations for function decoration block");
+ return false;
+ }
+
+ t = next();
+ if (!t.IsAttrRight()) {
+ set_error(t, "missing ]] for function decorations");
+ return false;
+ }
+ return true;
+}
+
+// function_decoration
+// : TODO(dsinclair) STAGE PAREN_LEFT pipeline_stage PAREN_RIGHT
+// | WORKGROUP_SIZE PAREN_LEFT INT_LITERAL
+// (COMMA INT_LITERAL (COMMA INT_LITERAL)?)? PAREN_RIGHT
+std::unique_ptr<ast::FunctionDecoration> ParserImpl::function_decoration() {
+ auto t = peek();
+ if (t.IsWorkgroupSize()) {
+ next(); // Consume the peek
+
+ t = next();
+ if (!t.IsParenLeft()) {
+ set_error(t, "missing ( for workgroup_size");
+ return nullptr;
+ }
+
+ t = next();
+ if (!t.IsSintLiteral()) {
+ set_error(t, "missing x value for workgroup_size");
+ return nullptr;
+ }
+ if (t.to_i32() <= 0) {
+ set_error(t, "invalid value for workgroup_size x parameter");
+ return nullptr;
+ }
+ int32_t x = t.to_i32();
+ int32_t y = 1;
+ int32_t z = 1;
+
+ t = peek();
+ if (t.IsComma()) {
+ next(); // Consume the peek
+
+ t = next();
+ if (!t.IsSintLiteral()) {
+ set_error(t, "missing y value for workgroup_size");
+ return nullptr;
+ }
+ if (t.to_i32() <= 0) {
+ set_error(t, "invalid value for workgroup_size y parameter");
+ return nullptr;
+ }
+ y = t.to_i32();
+
+ t = peek();
+ if (t.IsComma()) {
+ next(); // Consume the peek
+
+ t = next();
+ if (!t.IsSintLiteral()) {
+ set_error(t, "missing z value for workgroup_size");
+ return nullptr;
+ }
+ if (t.to_i32() <= 0) {
+ set_error(t, "invalid value for workgroup_size z parameter");
+ return nullptr;
+ }
+ z = t.to_i32();
+ }
+ }
+
+ t = next();
+ if (!t.IsParenRight()) {
+ set_error(t, "missing ) for workgroup_size");
+ return nullptr;
+ }
+
+ return std::make_unique<ast::WorkgroupDecoration>(uint32_t(x), uint32_t(y),
+ uint32_t(z));
+ }
+ return nullptr;
+}
+
// function_type_decl
// : type_decl
// | VOID
diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h
index 6da84df..578b4fd 100644
--- a/src/reader/wgsl/parser_impl.h
+++ b/src/reader/wgsl/parser_impl.h
@@ -182,6 +182,13 @@
/// Parses a `function_decl` grammar element
/// @returns the parsed function, nullptr otherwise
std::unique_ptr<ast::Function> function_decl();
+ /// Parses a `function_decoration_decl` grammar element
+ /// @param decos list to store the parsed decorations
+ /// @returns true on successful parse; false otherwise
+ bool function_decoration_decl(ast::FunctionDecorationList& decos);
+ /// Parses a `function_decoration` grammar element
+ /// @returns the parsed decoration, nullptr otherwise
+ std::unique_ptr<ast::FunctionDecoration> function_decoration();
/// Parses a `texture_sampler_types` grammar element
/// @returns the parsed Type or nullptr if none matched.
ast::type::Type* texture_sampler_types();
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index fe8e3c7..3ae5905 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -15,6 +15,7 @@
#include "gtest/gtest.h"
#include "src/ast/function.h"
#include "src/ast/type/type.h"
+#include "src/ast/workgroup_decoration.h"
#include "src/reader/wgsl/parser_impl.h"
#include "src/reader/wgsl/parser_impl_test_helper.h"
@@ -45,6 +46,113 @@
EXPECT_TRUE(body->get(0)->IsReturn());
}
+TEST_F(ParserImplTest, FunctionDecl_DecorationList) {
+ auto* p = parser("[[workgroup_size(2, 3, 4)]] fn main() -> void { return; }");
+ auto f = p->function_decl();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(f, nullptr);
+
+ EXPECT_EQ(f->name(), "main");
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->IsVoid());
+ ASSERT_EQ(f->params().size(), 0u);
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->IsVoid());
+
+ auto& decos = f->decorations();
+ ASSERT_EQ(decos.size(), 1u);
+ ASSERT_TRUE(decos[0]->IsWorkgroup());
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ std::tie(x, y, z) = decos[0]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 2u);
+ EXPECT_EQ(y, 3u);
+ EXPECT_EQ(z, 4u);
+
+ auto* body = f->body();
+ ASSERT_EQ(body->size(), 1u);
+ EXPECT_TRUE(body->get(0)->IsReturn());
+}
+
+TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
+ auto* p = parser(R"(
+[[workgroup_size(2, 3, 4), workgroup_size(5, 6, 7)]]
+fn main() -> void { return; })");
+ auto f = p->function_decl();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(f, nullptr);
+
+ EXPECT_EQ(f->name(), "main");
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->IsVoid());
+ ASSERT_EQ(f->params().size(), 0u);
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->IsVoid());
+
+ auto& decos = f->decorations();
+ ASSERT_EQ(decos.size(), 2u);
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ ASSERT_TRUE(decos[0]->IsWorkgroup());
+ std::tie(x, y, z) = decos[0]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 2u);
+ EXPECT_EQ(y, 3u);
+ EXPECT_EQ(z, 4u);
+
+ ASSERT_TRUE(decos[1]->IsWorkgroup());
+ std::tie(x, y, z) = decos[1]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 5u);
+ EXPECT_EQ(y, 6u);
+ EXPECT_EQ(z, 7u);
+
+ auto* body = f->body();
+ ASSERT_EQ(body->size(), 1u);
+ EXPECT_TRUE(body->get(0)->IsReturn());
+}
+
+TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
+ auto* p = parser(R"(
+[[workgroup_size(2, 3, 4)]]
+[[workgroup_size(5, 6, 7)]]
+fn main() -> void { return; })");
+ auto f = p->function_decl();
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_NE(f, nullptr);
+
+ EXPECT_EQ(f->name(), "main");
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->IsVoid());
+ ASSERT_EQ(f->params().size(), 0u);
+ ASSERT_NE(f->return_type(), nullptr);
+ EXPECT_TRUE(f->return_type()->IsVoid());
+
+ auto& decos = f->decorations();
+ ASSERT_EQ(decos.size(), 2u);
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ ASSERT_TRUE(decos[0]->IsWorkgroup());
+ std::tie(x, y, z) = decos[0]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 2u);
+ EXPECT_EQ(y, 3u);
+ EXPECT_EQ(z, 4u);
+
+ ASSERT_TRUE(decos[1]->IsWorkgroup());
+ std::tie(x, y, z) = decos[1]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 5u);
+ EXPECT_EQ(y, 6u);
+ EXPECT_EQ(z, 7u);
+
+ auto* body = f->body();
+ ASSERT_EQ(body->size(), 1u);
+ EXPECT_TRUE(body->get(0)->IsReturn());
+}
+
TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) {
auto* p = parser("fn main() -> { }");
auto f = p->function_decl();
diff --git a/src/reader/wgsl/parser_impl_function_decoration_list_test.cc b/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
new file mode 100644
index 0000000..9717eed
--- /dev/null
+++ b/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
@@ -0,0 +1,98 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "src/ast/workgroup_decoration.h"
+#include "src/reader/wgsl/parser_impl.h"
+#include "src/reader/wgsl/parser_impl_test_helper.h"
+
+namespace tint {
+namespace reader {
+namespace wgsl {
+namespace {
+
+TEST_F(ParserImplTest, FunctionDecorationList_Parses) {
+ auto* p = parser("[[workgroup_size(2), workgroup_size(3, 4, 5)]]");
+ ast::FunctionDecorationList decos;
+ ASSERT_TRUE(p->function_decoration_decl(decos));
+ ASSERT_FALSE(p->has_error()) << p->error();
+ ASSERT_EQ(decos.size(), 2u);
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ ASSERT_TRUE(decos[0]->IsWorkgroup());
+ std::tie(x, y, z) = decos[0]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 2u);
+
+ ASSERT_TRUE(decos[1]->IsWorkgroup());
+ std::tie(x, y, z) = decos[1]->AsWorkgroup()->values();
+ EXPECT_EQ(x, 3u);
+ EXPECT_EQ(y, 4u);
+ EXPECT_EQ(z, 5u);
+}
+
+TEST_F(ParserImplTest, FunctionDecorationList_Empty) {
+ auto* p = parser("[[]]");
+ ast::FunctionDecorationList decos;
+ ASSERT_FALSE(p->function_decoration_decl(decos));
+ ASSERT_TRUE(p->has_error());
+ ASSERT_EQ(p->error(),
+ "1:3: missing decorations for function decoration block");
+}
+
+TEST_F(ParserImplTest, FunctionDecorationList_Invalid) {
+ auto* p = parser("[[invalid]]");
+ ast::FunctionDecorationList decos;
+ ASSERT_TRUE(p->function_decoration_decl(decos));
+ ASSERT_FALSE(p->has_error());
+ ASSERT_TRUE(decos.empty());
+}
+
+TEST_F(ParserImplTest, FunctionDecorationList_ExtraComma) {
+ auto* p = parser("[[workgroup_size(2), ]]");
+ ast::FunctionDecorationList decos;
+ ASSERT_FALSE(p->function_decoration_decl(decos));
+ ASSERT_TRUE(p->has_error());
+ ASSERT_EQ(p->error(), "1:22: expected decoration but none found");
+}
+
+TEST_F(ParserImplTest, FunctionDecorationList_MissingComma) {
+ auto* p = parser("[[workgroup_size(2) workgroup_size(2)]]");
+ ast::FunctionDecorationList decos;
+ ASSERT_FALSE(p->function_decoration_decl(decos));
+ ASSERT_TRUE(p->has_error());
+ ASSERT_EQ(p->error(), "1:21: missing ]] for function decorations");
+}
+
+TEST_F(ParserImplTest, FunctionDecorationList_BadDecoration) {
+ auto* p = parser("[[workgroup_size()]]");
+ ast::FunctionDecorationList decos;
+ ASSERT_FALSE(p->function_decoration_decl(decos));
+ ASSERT_TRUE(p->has_error());
+ ASSERT_EQ(p->error(), "1:18: missing x value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecorationList_MissingRightAttr) {
+ auto* p = parser("[[workgroup_size(2), workgroup_size(3, 4, 5)");
+ ast::FunctionDecorationList decos;
+ ASSERT_FALSE(p->function_decoration_decl(decos));
+ ASSERT_TRUE(p->has_error());
+ ASSERT_EQ(p->error(), "1:45: missing ]] for function decorations");
+}
+
+} // namespace
+} // namespace wgsl
+} // namespace reader
+} // namespace tint
diff --git a/src/reader/wgsl/parser_impl_function_decoration_test.cc b/src/reader/wgsl/parser_impl_function_decoration_test.cc
new file mode 100644
index 0000000..c7acb60
--- /dev/null
+++ b/src/reader/wgsl/parser_impl_function_decoration_test.cc
@@ -0,0 +1,196 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "src/ast/workgroup_decoration.h"
+#include "src/reader/wgsl/parser_impl.h"
+#include "src/reader/wgsl/parser_impl_test_helper.h"
+
+namespace tint {
+namespace reader {
+namespace wgsl {
+namespace {
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup) {
+ auto* p = parser("workgroup_size(4)");
+ auto deco = p->function_decoration();
+ ASSERT_NE(deco, nullptr);
+ ASSERT_FALSE(p->has_error());
+ ASSERT_TRUE(deco->IsWorkgroup());
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ std::tie(x, y, z) = deco->AsWorkgroup()->values();
+ EXPECT_EQ(x, 4u);
+ EXPECT_EQ(y, 1u);
+ EXPECT_EQ(z, 1u);
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_2Param) {
+ auto* p = parser("workgroup_size(4, 5)");
+ auto deco = p->function_decoration();
+ ASSERT_NE(deco, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ ASSERT_TRUE(deco->IsWorkgroup());
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ std::tie(x, y, z) = deco->AsWorkgroup()->values();
+ EXPECT_EQ(x, 4u);
+ EXPECT_EQ(y, 5u);
+ EXPECT_EQ(z, 1u);
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_3Param) {
+ auto* p = parser("workgroup_size(4, 5, 6)");
+ auto deco = p->function_decoration();
+ ASSERT_NE(deco, nullptr);
+ ASSERT_FALSE(p->has_error());
+ ASSERT_TRUE(deco->IsWorkgroup());
+
+ uint32_t x = 0;
+ uint32_t y = 0;
+ uint32_t z = 0;
+ std::tie(x, y, z) = deco->AsWorkgroup()->values();
+ EXPECT_EQ(x, 4u);
+ EXPECT_EQ(y, 5u);
+ EXPECT_EQ(z, 6u);
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_TooManyValues) {
+ auto* p = parser("workgroup_size(1, 2, 3, 4)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:23: missing ) for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Invalid_X_Value) {
+ auto* p = parser("workgroup_size(-2, 5, 6)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:16: invalid value for workgroup_size x parameter");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Invalid_Y_Value) {
+ auto* p = parser("workgroup_size(4, 0, 6)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:19: invalid value for workgroup_size y parameter");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Invalid_Z_Value) {
+ auto* p = parser("workgroup_size(4, 5, -3)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:22: invalid value for workgroup_size z parameter");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_MissingLeftParam) {
+ auto* p = parser("workgroup_size 4, 5, 6)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:16: missing ( for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_MissingRightParam) {
+ auto* p = parser("workgroup_size(4, 5, 6");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:23: missing ) for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_MissingValues) {
+ auto* p = parser("workgroup_size()");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:16: missing x value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_X_Value) {
+ auto* p = parser("workgroup_size(, 2, 3)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:16: missing x value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_Y_Comma) {
+ auto* p = parser("workgroup_size(1 2, 3)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:18: missing ) for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_Y_Value) {
+ auto* p = parser("workgroup_size(1, , 3)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:19: missing y value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_Z_Comma) {
+ auto* p = parser("workgroup_size(1, 2 3)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:21: missing ) for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_Z_Value) {
+ auto* p = parser("workgroup_size(1, 2, )");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:22: missing z value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_X_Invalid) {
+ auto* p = parser("workgroup_size(nan)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:16: missing x value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_Y_Invalid) {
+ auto* p = parser("workgroup_size(2, nan)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:19: missing y value for workgroup_size");
+}
+
+TEST_F(ParserImplTest, FunctionDecoration_Workgroup_Missing_Z_Invalid) {
+ auto* p = parser("workgroup_size(2, 3, nan)");
+ auto deco = p->function_decoration();
+ ASSERT_EQ(deco, nullptr);
+ ASSERT_TRUE(p->has_error());
+ EXPECT_EQ(p->error(), "1:22: missing z value for workgroup_size");
+}
+
+} // namespace
+} // namespace wgsl
+} // namespace reader
+} // namespace tint
diff --git a/src/reader/wgsl/parser_impl_global_decl_test.cc b/src/reader/wgsl/parser_impl_global_decl_test.cc
index 489f07d..c04735b 100644
--- a/src/reader/wgsl/parser_impl_global_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_global_decl_test.cc
@@ -166,6 +166,16 @@
EXPECT_EQ(m.functions()[0]->name(), "main");
}
+TEST_F(ParserImplTest, GlobalDecl_Function_WithDecoration) {
+ auto* p = parser("[[workgroup_size(2)]] fn main() -> void { return; }");
+ p->global_decl();
+ ASSERT_FALSE(p->has_error()) << p->error();
+
+ auto m = p->module();
+ ASSERT_EQ(m.functions().size(), 1u);
+ EXPECT_EQ(m.functions()[0]->name(), "main");
+}
+
TEST_F(ParserImplTest, GlobalDecl_Function_Invalid) {
auto* p = parser("fn main() -> { return; }");
p->global_decl();
diff --git a/src/reader/wgsl/parser_impl_variable_decoration_list_test.cc b/src/reader/wgsl/parser_impl_variable_decoration_list_test.cc
index 33719ca..942abe9 100644
--- a/src/reader/wgsl/parser_impl_variable_decoration_list_test.cc
+++ b/src/reader/wgsl/parser_impl_variable_decoration_list_test.cc
@@ -44,8 +44,8 @@
TEST_F(ParserImplTest, VariableDecorationList_Invalid) {
auto* p = parser(R"([[invalid]])");
auto decos = p->variable_decoration_list();
- ASSERT_TRUE(p->has_error());
- ASSERT_EQ(p->error(), "1:3: missing variable decoration for decoration list");
+ ASSERT_FALSE(p->has_error());
+ ASSERT_TRUE(decos.empty());
}
TEST_F(ParserImplTest, VariableDecorationList_ExtraComma) {
diff --git a/src/reader/wgsl/parser_impl_variable_decoration_test.cc b/src/reader/wgsl/parser_impl_variable_decoration_test.cc
index 070e8ec..d410b6d 100644
--- a/src/reader/wgsl/parser_impl_variable_decoration_test.cc
+++ b/src/reader/wgsl/parser_impl_variable_decoration_test.cc
@@ -117,7 +117,6 @@
BuiltinData{"front_facing", ast::Builtin::kFrontFacing},
BuiltinData{"frag_coord", ast::Builtin::kFragCoord},
BuiltinData{"frag_depth", ast::Builtin::kFragDepth},
- BuiltinData{"workgroup_size", ast::Builtin::kWorkgroupSize},
BuiltinData{"local_invocation_id", ast::Builtin::kLocalInvocationId},
BuiltinData{"local_invocation_idx", ast::Builtin::kLocalInvocationIdx},
BuiltinData{"global_invocation_id",
diff --git a/src/reader/wgsl/token.cc b/src/reader/wgsl/token.cc
index dd1ad3e..ed90e8a 100644
--- a/src/reader/wgsl/token.cc
+++ b/src/reader/wgsl/token.cc
@@ -349,6 +349,8 @@
return "void";
case Token::Type::kWorkgroup:
return "workgroup";
+ case Token::Type::kWorkgroupSize:
+ return "workgroup_size";
}
return "<unknown>";
diff --git a/src/reader/wgsl/token.h b/src/reader/wgsl/token.h
index da18df3..0967cd7 100644
--- a/src/reader/wgsl/token.h
+++ b/src/reader/wgsl/token.h
@@ -359,7 +359,9 @@
/// A 'void'
kVoid,
/// A 'workgroup'
- kWorkgroup
+ kWorkgroup,
+ /// A 'workgroup_size'
+ kWorkgroupSize,
};
/// Converts a token type to a name
@@ -777,6 +779,8 @@
bool IsVoid() const { return type_ == Type::kVoid; }
/// @returns true if token is a 'workgroup'
bool IsWorkgroup() const { return type_ == Type::kWorkgroup; }
+ /// @returns true if token is a 'workgroup_size'
+ bool IsWorkgroupSize() const { return type_ == Type::kWorkgroupSize; }
/// @returns the source line of the token
size_t line() const { return source_.line; }
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index d34e572..f81e36c 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1395,11 +1395,6 @@
return "SV_Position";
case ast::Builtin::kFragDepth:
return "SV_Depth";
- // TODO(dsinclair): Ignore for now. This has been removed as a builtin
- // in the spec. Need to update Tint to match.
- // https://github.com/gpuweb/gpuweb/pull/824
- case ast::Builtin::kWorkgroupSize:
- return "";
case ast::Builtin::kLocalInvocationId:
return "SV_GroupThreadID";
case ast::Builtin::kLocalInvocationIdx:
diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc
index 6b2324a..39b39dc 100644
--- a/src/writer/hlsl/generator_impl_test.cc
+++ b/src/writer/hlsl/generator_impl_test.cc
@@ -88,7 +88,6 @@
HlslBuiltinData{ast::Builtin::kFrontFacing, "SV_IsFrontFacing"},
HlslBuiltinData{ast::Builtin::kFragCoord, "SV_Position"},
HlslBuiltinData{ast::Builtin::kFragDepth, "SV_Depth"},
- HlslBuiltinData{ast::Builtin::kWorkgroupSize, ""},
HlslBuiltinData{ast::Builtin::kLocalInvocationId, "SV_GroupThreadID"},
HlslBuiltinData{ast::Builtin::kLocalInvocationIdx, "SV_GroupIndex"},
HlslBuiltinData{ast::Builtin::kGlobalInvocationId,
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 6bc21fd..2251e3b 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -1257,11 +1257,6 @@
return "position";
case ast::Builtin::kFragDepth:
return "depth(any)";
- // TODO(dsinclair): Ignore for now. This has been removed as a builtin
- // in the spec. Need to update Tint to match.
- // https://github.com/gpuweb/gpuweb/pull/824
- case ast::Builtin::kWorkgroupSize:
- return "";
case ast::Builtin::kLocalInvocationId:
return "thread_position_in_threadgroup";
case ast::Builtin::kLocalInvocationIdx:
diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc
index 922ca9d..f57bf52 100644
--- a/src/writer/msl/generator_impl_test.cc
+++ b/src/writer/msl/generator_impl_test.cc
@@ -116,7 +116,6 @@
MslBuiltinData{ast::Builtin::kFrontFacing, "front_facing"},
MslBuiltinData{ast::Builtin::kFragCoord, "position"},
MslBuiltinData{ast::Builtin::kFragDepth, "depth(any)"},
- MslBuiltinData{ast::Builtin::kWorkgroupSize, ""},
MslBuiltinData{ast::Builtin::kLocalInvocationId,
"thread_position_in_threadgroup"},
MslBuiltinData{ast::Builtin::kLocalInvocationIdx,
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 879ded3..050fd13 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -2309,8 +2309,6 @@
return SpvBuiltInFragCoord;
case ast::Builtin::kFragDepth:
return SpvBuiltInFragDepth;
- case ast::Builtin::kWorkgroupSize:
- return SpvBuiltInWorkgroupSize;
case ast::Builtin::kLocalInvocationId:
return SpvBuiltInLocalInvocationId;
case ast::Builtin::kLocalInvocationIdx:
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 5ab3ce2..1e60c0f 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -389,7 +389,6 @@
BuiltinData{ast::Builtin::kFrontFacing, SpvBuiltInFrontFacing},
BuiltinData{ast::Builtin::kFragCoord, SpvBuiltInFragCoord},
BuiltinData{ast::Builtin::kFragDepth, SpvBuiltInFragDepth},
- BuiltinData{ast::Builtin::kWorkgroupSize, SpvBuiltInWorkgroupSize},
BuiltinData{ast::Builtin::kLocalInvocationId,
SpvBuiltInLocalInvocationId},
BuiltinData{ast::Builtin::kLocalInvocationIdx,
diff --git a/test/function.wgsl b/test/function.wgsl
index e8062f6..b7060fb 100644
--- a/test/function.wgsl
+++ b/test/function.wgsl
@@ -15,3 +15,9 @@
fn main() -> f32 {
return ((2. * 3.) - 4.) / 5.;
}
+
+[[workgroup_size(2)]]
+fn ep() -> void {
+ return;
+}
+entry_point compute = ep;