ast: Support non-literal workgroup_size parameters
Change the type of the values in an ast::WorkgroupDecoration to be
ast::Expression nodes, so that they can represent both
ast::ScalarExpression (literal) and ast::IdentifierExpression
(module-scope constant).
The Resolver processes these nodes to produce a uint32_t for the
default value on each dimension, and captures a reference to the
module-scope constant if it is overridable (which will soon be used by
the inspector and backends).
The WGSL parser now uses `primary_expression` to parse arguments to
workgroup_size.
Also added some WorkgroupSize() helpers to ProgramBuilder.
Bug: tint:713
Change-Id: I44b7b0021b925c84f25f65e26dc7da6b19ede508
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51262
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 44ba686..a634e6f 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -107,7 +107,7 @@
ProgramBuilder b2;
b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
DecorationList{
- b2.create<WorkgroupDecoration>(2, 4, 6),
+ b2.WorkgroupSize(2, 4, 6),
});
},
"internal compiler error");
@@ -121,7 +121,7 @@
b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
DecorationList{},
DecorationList{
- b2.create<WorkgroupDecoration>(2, 4, 6),
+ b2.WorkgroupSize(2, 4, 6),
});
},
"internal compiler error");
@@ -159,10 +159,14 @@
StatementList{
create<DiscardStatement>(),
},
- DecorationList{create<WorkgroupDecoration>(2, 4, 6)});
+ DecorationList{WorkgroupSize(2, 4, 6)});
EXPECT_EQ(str(f), R"(Function func -> __void
-WorkgroupDecoration{2 4 6}
+WorkgroupDecoration{
+ ScalarConstructor[not set]{2}
+ ScalarConstructor[not set]{4}
+ ScalarConstructor[not set]{6}
+}
()
{
Discard{}
diff --git a/src/ast/int_literal.h b/src/ast/int_literal.h
index 1b5c7ef..4a12ef6 100644
--- a/src/ast/int_literal.h
+++ b/src/ast/int_literal.h
@@ -25,6 +25,9 @@
public:
~IntLiteral() override;
+ /// @returns the literal value as an i32
+ int32_t value_as_i32() const { return static_cast<int32_t>(value_); }
+
/// @returns the literal value as a u32
uint32_t value_as_u32() const { return value_; }
diff --git a/src/ast/workgroup_decoration.cc b/src/ast/workgroup_decoration.cc
index ba7d3ff..36a0d40 100644
--- a/src/ast/workgroup_decoration.cc
+++ b/src/ast/workgroup_decoration.cc
@@ -23,36 +23,36 @@
WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
const Source& source,
- uint32_t x)
- : WorkgroupDecoration(program_id, source, x, 1, 1) {}
-
-WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
- const Source& source,
- uint32_t x,
- uint32_t y)
- : WorkgroupDecoration(program_id, source, x, y, 1) {}
-
-WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
- const Source& source,
- uint32_t x,
- uint32_t y,
- uint32_t z)
+ ast::Expression* x,
+ ast::Expression* y,
+ ast::Expression* z)
: Base(program_id, source), x_(x), y_(y), z_(z) {}
WorkgroupDecoration::~WorkgroupDecoration() = default;
-void WorkgroupDecoration::to_str(const sem::Info&,
+void WorkgroupDecoration::to_str(const sem::Info& sem,
std::ostream& out,
size_t indent) const {
make_indent(out, indent);
- out << "WorkgroupDecoration{" << x_ << " " << y_ << " " << z_ << "}"
- << std::endl;
+ out << "WorkgroupDecoration{" << std::endl;
+ x_->to_str(sem, out, indent + 2);
+ if (y_) {
+ y_->to_str(sem, out, indent + 2);
+ if (z_) {
+ z_->to_str(sem, out, indent + 2);
+ }
+ }
+ make_indent(out, indent);
+ out << "}" << std::endl;
}
WorkgroupDecoration* WorkgroupDecoration::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source());
- return ctx->dst->create<WorkgroupDecoration>(src, x_, y_, z_);
+ auto* x = ctx->Clone(x_);
+ auto* y = ctx->Clone(y_);
+ auto* z = ctx->Clone(z_);
+ return ctx->dst->create<WorkgroupDecoration>(src, x, y, z);
}
} // namespace ast
diff --git a/src/ast/workgroup_decoration.h b/src/ast/workgroup_decoration.h
index 5838168..fa68e87 100644
--- a/src/ast/workgroup_decoration.h
+++ b/src/ast/workgroup_decoration.h
@@ -15,47 +15,35 @@
#ifndef SRC_AST_WORKGROUP_DECORATION_H_
#define SRC_AST_WORKGROUP_DECORATION_H_
-#include <tuple>
+#include <array>
#include "src/ast/decoration.h"
namespace tint {
namespace ast {
+// Forward declaration
+class Expression;
+
/// A workgroup decoration
class WorkgroupDecoration : public Castable<WorkgroupDecoration, Decoration> {
public:
/// constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the source of this decoration
- /// @param x the workgroup x dimension size
- WorkgroupDecoration(ProgramID program_id, const Source& source, uint32_t x);
- /// constructor
- /// @param program_id the identifier of the program that owns this node
- /// @param source the source of this decoration
- /// @param x the workgroup x dimension size
- /// @param y the workgroup x dimension size
+ /// @param x the workgroup x dimension expression
+ /// @param y the optional workgroup y dimension expression
+ /// @param z the optional workgroup z dimension expression
WorkgroupDecoration(ProgramID program_id,
const Source& source,
- uint32_t x,
- uint32_t y);
- /// constructor
- /// @param program_id the identifier of the program that owns this node
- /// @param source the source of this decoration
- /// @param x the workgroup x dimension size
- /// @param y the workgroup x dimension size
- /// @param z the workgroup x dimension size
- WorkgroupDecoration(ProgramID program_id,
- const Source& source,
- uint32_t x,
- uint32_t y,
- uint32_t z);
+ ast::Expression* x,
+ ast::Expression* y = nullptr,
+ ast::Expression* z = nullptr);
+
~WorkgroupDecoration() override;
/// @returns the workgroup dimensions
- std::tuple<uint32_t, uint32_t, uint32_t> values() const {
- return {x_, y_, z_};
- }
+ std::array<ast::Expression*, 3> values() const { return {x_, y_, z_}; }
/// Outputs the decoration to the given stream
/// @param sem the semantic info for the program
@@ -72,9 +60,9 @@
WorkgroupDecoration* Clone(CloneContext* ctx) const override;
private:
- uint32_t const x_;
- uint32_t const y_;
- uint32_t const z_;
+ ast::Expression* x_ = nullptr;
+ ast::Expression* y_ = nullptr;
+ ast::Expression* z_ = nullptr;
};
} // namespace ast
diff --git a/src/ast/workgroup_decoration_test.cc b/src/ast/workgroup_decoration_test.cc
index c5c697f..1f0756d 100644
--- a/src/ast/workgroup_decoration_test.cc
+++ b/src/ast/workgroup_decoration_test.cc
@@ -24,40 +24,84 @@
using WorkgroupDecorationTest = TestHelper;
TEST_F(WorkgroupDecorationTest, Creation_1param) {
- auto* d = create<WorkgroupDecoration>(2);
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = d->values();
- EXPECT_EQ(x, 2u);
- EXPECT_EQ(y, 1u);
- EXPECT_EQ(z, 1u);
+ auto* d = WorkgroupSize(2);
+ auto values = d->values();
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(x_scalar);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ EXPECT_EQ(values[1], nullptr);
+ EXPECT_EQ(values[2], nullptr);
}
TEST_F(WorkgroupDecorationTest, Creation_2param) {
- auto* d = create<WorkgroupDecoration>(2, 4);
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = d->values();
- EXPECT_EQ(x, 2u);
- EXPECT_EQ(y, 4u);
- EXPECT_EQ(z, 1u);
+ auto* d = WorkgroupSize(2, 4);
+ auto values = d->values();
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(x_scalar);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(y_scalar);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ EXPECT_EQ(values[2], nullptr);
}
TEST_F(WorkgroupDecorationTest, Creation_3param) {
- auto* d = create<WorkgroupDecoration>(2, 4, 6);
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = d->values();
- EXPECT_EQ(x, 2u);
- EXPECT_EQ(y, 4u);
- EXPECT_EQ(z, 6u);
+ auto* d = WorkgroupSize(2, 4, 6);
+ auto values = d->values();
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(x_scalar);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(y_scalar);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_NE(values[2], nullptr);
+ auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(z_scalar);
+ ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 6u);
+}
+
+TEST_F(WorkgroupDecorationTest, Creation_WithIdentifier) {
+ auto* d = WorkgroupSize(2, 4, "depth");
+ auto values = d->values();
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(x_scalar);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_TRUE(y_scalar);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_NE(values[2], nullptr);
+ auto* z_ident = values[2]->As<ast::IdentifierExpression>();
+ ASSERT_TRUE(z_ident);
+ EXPECT_EQ(Symbols().NameFor(z_ident->symbol()), "depth");
}
TEST_F(WorkgroupDecorationTest, ToStr) {
- auto* d = create<WorkgroupDecoration>(2, 4, 6);
- EXPECT_EQ(str(d), R"(WorkgroupDecoration{2 4 6}
+ auto* d = WorkgroupSize(2, "height");
+ EXPECT_EQ(str(d), R"(WorkgroupDecoration{
+ ScalarConstructor[not set]{2}
+ Identifier[not set]{height}
+}
)");
}
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 2c5f01f..62d95a7 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -828,10 +828,8 @@
}
TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
- MakeEmptyBodyFunction("foo", ast::DecorationList{
- Stage(ast::PipelineStage::kCompute),
- create<ast::WorkgroupDecoration>(8u, 2u, 1u),
- });
+ MakeEmptyBodyFunction(
+ "foo", {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 1)});
Inspector& inspector = Build();
diff --git a/src/program_builder.h b/src/program_builder.h
index 610f1eb..05acee4 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -59,6 +59,7 @@
#include "src/ast/variable_decl_statement.h"
#include "src/ast/vector.h"
#include "src/ast/void.h"
+#include "src/ast/workgroup_decoration.h"
#include "src/program.h"
#include "src/program_id.h"
#include "src/sem/array.h"
@@ -1914,6 +1915,36 @@
return create<ast::StageDecoration>(source_, stage);
}
+ /// Creates an ast::WorkgroupDecoration
+ /// @param x the x dimension expression
+ /// @returns the workgroup decoration pointer
+ template <typename EXPR_X>
+ ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x) {
+ return WorkgroupSize(std::forward<EXPR_X>(x), nullptr, nullptr);
+ }
+
+ /// Creates an ast::WorkgroupDecoration
+ /// @param x the x dimension expression
+ /// @param y the y dimension expression
+ /// @returns the workgroup decoration pointer
+ template <typename EXPR_X, typename EXPR_Y>
+ ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
+ return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y),
+ nullptr);
+ }
+
+ /// Creates an ast::WorkgroupDecoration
+ /// @param x the x dimension expression
+ /// @param y the y dimension expression
+ /// @param z the z dimension expression
+ /// @returns the workgroup decoration pointer
+ template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z>
+ ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
+ return create<ast::WorkgroupDecoration>(
+ source_, Expr(std::forward<EXPR_X>(x)), Expr(std::forward<EXPR_Y>(y)),
+ Expr(std::forward<EXPR_Z>(z)));
+ }
+
/// Sets the current builder source to `src`
/// @param src the Source used for future create() calls
void SetSource(const Source& src) {
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 9398034..8c91da6 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -3010,26 +3010,35 @@
if (s == kWorkgroupSizeDecoration) {
return expect_paren_block("workgroup_size decoration", [&]() -> Result {
- uint32_t x;
- uint32_t y = 1;
- uint32_t z = 1;
+ ast::Expression* x = nullptr;
+ ast::Expression* y = nullptr;
+ ast::Expression* z = nullptr;
- auto val = expect_nonzero_positive_sint("workgroup_size x parameter");
- if (val.errored)
+ auto expr = primary_expression();
+ if (expr.errored) {
return Failure::kErrored;
- x = val.value;
+ } else if (!expr.matched) {
+ return add_error(peek(), "expected workgroup_size x parameter");
+ }
+ x = std::move(expr.value);
if (match(Token::Type::kComma)) {
- val = expect_nonzero_positive_sint("workgroup_size y parameter");
- if (val.errored)
+ expr = primary_expression();
+ if (expr.errored) {
return Failure::kErrored;
- y = val.value;
+ } else if (!expr.matched) {
+ return add_error(peek(), "expected workgroup_size y parameter");
+ }
+ y = std::move(expr.value);
if (match(Token::Type::kComma)) {
- val = expect_nonzero_positive_sint("workgroup_size z parameter");
- if (val.errored)
+ expr = primary_expression();
+ if (expr.errored) {
return Failure::kErrored;
- z = val.value;
+ } else if (!expr.matched) {
+ return add_error(peek(), "expected workgroup_size z parameter");
+ }
+ z = std::move(expr.value);
}
}
diff --git a/src/reader/wgsl/parser_impl_error_msg_test.cc b/src/reader/wgsl/parser_impl_error_msg_test.cc
index c9c7c83..ba54f60 100644
--- a/src/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/reader/wgsl/parser_impl_error_msg_test.cc
@@ -325,74 +325,23 @@
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXInvalid) {
- EXPECT("[[workgroup_size(x)]] fn f() {}",
- "test.wgsl:1:18 error: expected signed integer literal for "
- "workgroup_size x parameter\n"
- "[[workgroup_size(x)]] fn f() {}\n"
- " ^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXNegative) {
- EXPECT("[[workgroup_size(-1)]] fn f() {}",
- "test.wgsl:1:18 error: workgroup_size x parameter must be greater "
- "than 0\n"
- "[[workgroup_size(-1)]] fn f() {}\n"
- " ^^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXZero) {
- EXPECT("[[workgroup_size(0)]] fn f() {}",
- "test.wgsl:1:18 error: workgroup_size x parameter must be greater "
- "than 0\n"
- "[[workgroup_size(0)]] fn f() {}\n"
+ EXPECT("[[workgroup_size(@)]] fn f() {}",
+ "test.wgsl:1:18 error: expected workgroup_size x parameter\n"
+ "[[workgroup_size(@)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYInvalid) {
- EXPECT("[[workgroup_size(1, x)]] fn f() {}",
- "test.wgsl:1:21 error: expected signed integer literal for "
- "workgroup_size y parameter\n"
- "[[workgroup_size(1, x)]] fn f() {}\n"
- " ^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYNegative) {
- EXPECT("[[workgroup_size(1, -1)]] fn f() {}",
- "test.wgsl:1:21 error: workgroup_size y parameter must be greater "
- "than 0\n"
- "[[workgroup_size(1, -1)]] fn f() {}\n"
- " ^^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYZero) {
- EXPECT("[[workgroup_size(1, 0)]] fn f() {}",
- "test.wgsl:1:21 error: workgroup_size y parameter must be greater "
- "than 0\n"
- "[[workgroup_size(1, 0)]] fn f() {}\n"
+ EXPECT("[[workgroup_size(1, @)]] fn f() {}",
+ "test.wgsl:1:21 error: expected workgroup_size y parameter\n"
+ "[[workgroup_size(1, @)]] fn f() {}\n"
" ^\n");
}
TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZInvalid) {
- EXPECT("[[workgroup_size(1, 2, x)]] fn f() {}",
- "test.wgsl:1:24 error: expected signed integer literal for "
- "workgroup_size z parameter\n"
- "[[workgroup_size(1, 2, x)]] fn f() {}\n"
- " ^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZNegative) {
- EXPECT("[[workgroup_size(1, 2, -1)]] fn f() {}",
- "test.wgsl:1:24 error: workgroup_size z parameter must be greater "
- "than 0\n"
- "[[workgroup_size(1, 2, -1)]] fn f() {}\n"
- " ^^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZZero) {
- EXPECT("[[workgroup_size(1, 2, 0)]] fn f() {}",
- "test.wgsl:1:24 error: workgroup_size z parameter must be greater "
- "than 0\n"
- "[[workgroup_size(1, 2, 0)]] fn f() {}\n"
+ EXPECT("[[workgroup_size(1, 2, @)]] fn f() {}",
+ "test.wgsl:1:24 error: expected workgroup_size z parameter\n"
+ "[[workgroup_size(1, 2, @)]] fn f() {}\n"
" ^\n");
}
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index 3144df3..f08db28 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -69,13 +69,25 @@
ASSERT_EQ(decorations.size(), 1u);
ASSERT_TRUE(decorations[0]->Is<ast::WorkgroupDecoration>());
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = decorations[0]->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 2u);
- EXPECT_EQ(y, 3u);
- EXPECT_EQ(z, 4u);
+ auto values = decorations[0]->As<ast::WorkgroupDecoration>()->values();
+
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(y_scalar, nullptr);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
+
+ ASSERT_NE(values[2], nullptr);
+ auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(z_scalar, nullptr);
+ ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
@@ -84,7 +96,7 @@
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
auto p = parser(R"(
-[[workgroup_size(2, 3, 4), workgroup_size(5, 6, 7)]]
+[[workgroup_size(2, 3, 4), stage(compute)]]
fn main() { return; })");
auto decos = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
@@ -104,20 +116,30 @@
auto& decorations = f->decorations();
ASSERT_EQ(decorations.size(), 2u);
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
ASSERT_TRUE(decorations[0]->Is<ast::WorkgroupDecoration>());
- std::tie(x, y, z) = decorations[0]->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 2u);
- EXPECT_EQ(y, 3u);
- EXPECT_EQ(z, 4u);
+ auto values = decorations[0]->As<ast::WorkgroupDecoration>()->values();
- ASSERT_TRUE(decorations[1]->Is<ast::WorkgroupDecoration>());
- std::tie(x, y, z) = decorations[1]->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 5u);
- EXPECT_EQ(y, 6u);
- EXPECT_EQ(z, 7u);
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(y_scalar, nullptr);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
+
+ ASSERT_NE(values[2], nullptr);
+ auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(z_scalar, nullptr);
+ ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_TRUE(decorations[1]->Is<ast::StageDecoration>());
+ EXPECT_EQ(decorations[1]->As<ast::StageDecoration>()->value(),
+ ast::PipelineStage::kCompute);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
@@ -127,7 +149,7 @@
TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
auto p = parser(R"(
[[workgroup_size(2, 3, 4)]]
-[[workgroup_size(5, 6, 7)]]
+[[stage(compute)]]
fn main() { return; })");
auto decorations = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
@@ -147,20 +169,30 @@
auto& decos = f->decorations();
ASSERT_EQ(decos.size(), 2u);
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
ASSERT_TRUE(decos[0]->Is<ast::WorkgroupDecoration>());
- std::tie(x, y, z) = decos[0]->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 2u);
- EXPECT_EQ(y, 3u);
- EXPECT_EQ(z, 4u);
+ auto values = decos[0]->As<ast::WorkgroupDecoration>()->values();
- ASSERT_TRUE(decos[1]->Is<ast::WorkgroupDecoration>());
- std::tie(x, y, z) = decos[1]->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 5u);
- EXPECT_EQ(y, 6u);
- EXPECT_EQ(z, 7u);
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(y_scalar, nullptr);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
+
+ ASSERT_NE(values[2], nullptr);
+ auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(z_scalar, nullptr);
+ ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_TRUE(decos[1]->Is<ast::StageDecoration>());
+ EXPECT_EQ(decos[1]->As<ast::StageDecoration>()->value(),
+ ast::PipelineStage::kCompute);
auto* body = f->body();
ASSERT_EQ(body->size(), 1u);
diff --git a/src/reader/wgsl/parser_impl_function_decoration_list_test.cc b/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
index c2272c4..0b76ef2 100644
--- a/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
@@ -21,7 +21,7 @@
namespace {
TEST_F(ParserImplTest, DecorationList_Parses) {
- auto p = parser("[[workgroup_size(2), workgroup_size(3, 4, 5)]]");
+ auto p = parser("[[workgroup_size(2), stage(compute)]]");
auto decos = p->decoration_list();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(decos.errored);
@@ -33,18 +33,17 @@
ASSERT_NE(deco_0, nullptr);
ASSERT_NE(deco_1, nullptr);
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
ASSERT_TRUE(deco_0->Is<ast::WorkgroupDecoration>());
- std::tie(x, y, z) = deco_0->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 2u);
+ ast::Expression* x = deco_0->As<ast::WorkgroupDecoration>()->values()[0];
+ ASSERT_NE(x, nullptr);
+ auto* x_scalar = x->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
- ASSERT_TRUE(deco_1->Is<ast::WorkgroupDecoration>());
- std::tie(x, y, z) = deco_1->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 3u);
- EXPECT_EQ(y, 4u);
- EXPECT_EQ(z, 5u);
+ ASSERT_TRUE(deco_1->Is<ast::StageDecoration>());
+ EXPECT_EQ(deco_1->As<ast::StageDecoration>()->value(),
+ ast::PipelineStage::kCompute);
}
TEST_F(ParserImplTest, DecorationList_Empty) {
@@ -85,14 +84,12 @@
}
TEST_F(ParserImplTest, DecorationList_BadDecoration) {
- auto p = parser("[[workgroup_size()]]");
+ auto p = parser("[[stage()]]");
auto decos = p->decoration_list();
EXPECT_TRUE(p->has_error());
EXPECT_TRUE(decos.errored);
EXPECT_FALSE(decos.matched);
- EXPECT_EQ(
- p->error(),
- "1:18: expected signed integer literal for workgroup_size x parameter");
+ EXPECT_EQ(p->error(), "1:9: invalid value for stage decoration");
}
TEST_F(ParserImplTest, DecorationList_MissingRightAttr) {
diff --git a/src/reader/wgsl/parser_impl_function_decoration_test.cc b/src/reader/wgsl/parser_impl_function_decoration_test.cc
index dc9fd1f..432f99e 100644
--- a/src/reader/wgsl/parser_impl_function_decoration_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decoration_test.cc
@@ -32,13 +32,16 @@
ASSERT_NE(func_deco, nullptr);
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 4u);
- EXPECT_EQ(y, 1u);
- EXPECT_EQ(z, 1u);
+ auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ EXPECT_EQ(values[1], nullptr);
+ EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Decoration_Workgroup_2Param) {
@@ -52,13 +55,21 @@
ASSERT_NE(func_deco, nullptr) << p->error();
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 4u);
- EXPECT_EQ(y, 5u);
- EXPECT_EQ(z, 1u);
+ auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(y_scalar, nullptr);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 5u);
+
+ EXPECT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Decoration_Workgroup_3Param) {
@@ -72,13 +83,52 @@
ASSERT_NE(func_deco, nullptr);
ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
- EXPECT_EQ(x, 4u);
- EXPECT_EQ(y, 5u);
- EXPECT_EQ(z, 6u);
+ auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(y_scalar, nullptr);
+ ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 5u);
+
+ ASSERT_NE(values[2], nullptr);
+ auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(z_scalar, nullptr);
+ ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 6u);
+}
+
+TEST_F(ParserImplTest, Decoration_Workgroup_WithIdent) {
+ auto p = parser("workgroup_size(4, height)");
+ auto deco = p->decoration();
+ EXPECT_TRUE(deco.matched);
+ EXPECT_FALSE(deco.errored);
+ ASSERT_NE(deco.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_deco = deco.value->As<ast::Decoration>();
+ ASSERT_NE(func_deco, nullptr);
+ ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
+
+ auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+ ASSERT_NE(values[0], nullptr);
+ auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+ ASSERT_NE(x_scalar, nullptr);
+ ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+ EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+ ASSERT_NE(values[1], nullptr);
+ auto* y_ident = values[1]->As<ast::IdentifierExpression>();
+ ASSERT_NE(y_ident, nullptr);
+ EXPECT_EQ(p->builder().Symbols().NameFor(y_ident->symbol()), "height");
+
+ ASSERT_EQ(values[2], nullptr);
}
TEST_F(ParserImplTest, Decoration_Workgroup_TooManyValues) {
@@ -91,39 +141,6 @@
EXPECT_EQ(p->error(), "1:23: expected ')' for workgroup_size decoration");
}
-TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_X_Value) {
- auto p = parser("workgroup_size(-2, 5, 6)");
- auto deco = p->decoration();
- EXPECT_FALSE(deco.matched);
- EXPECT_TRUE(deco.errored);
- EXPECT_EQ(deco.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(),
- "1:16: workgroup_size x parameter must be greater than 0");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_Y_Value) {
- auto p = parser("workgroup_size(4, 0, 6)");
- auto deco = p->decoration();
- EXPECT_FALSE(deco.matched);
- EXPECT_TRUE(deco.errored);
- EXPECT_EQ(deco.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(),
- "1:19: workgroup_size y parameter must be greater than 0");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_Z_Value) {
- auto p = parser("workgroup_size(4, 5, -3)");
- auto deco = p->decoration();
- EXPECT_FALSE(deco.matched);
- EXPECT_TRUE(deco.errored);
- EXPECT_EQ(deco.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(p->error(),
- "1:22: workgroup_size z parameter must be greater than 0");
-}
-
TEST_F(ParserImplTest, Decoration_Workgroup_MissingLeftParam) {
auto p = parser("workgroup_size 4, 5, 6)");
auto deco = p->decoration();
@@ -151,9 +168,7 @@
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:16: expected signed integer literal for workgroup_size x parameter");
+ EXPECT_EQ(p->error(), "1:16: expected workgroup_size x parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Value) {
@@ -163,9 +178,7 @@
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:16: expected signed integer literal for workgroup_size x parameter");
+ EXPECT_EQ(p->error(), "1:16: expected workgroup_size x parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Comma) {
@@ -185,9 +198,7 @@
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:19: expected signed integer literal for workgroup_size y parameter");
+ EXPECT_EQ(p->error(), "1:19: expected workgroup_size y parameter");
}
TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Comma) {
@@ -207,45 +218,7 @@
EXPECT_TRUE(deco.errored);
EXPECT_EQ(deco.value, nullptr);
EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:22: expected signed integer literal for workgroup_size z parameter");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Invalid) {
- auto p = parser("workgroup_size(nan)");
- auto deco = p->decoration();
- EXPECT_FALSE(deco.matched);
- EXPECT_TRUE(deco.errored);
- EXPECT_EQ(deco.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:16: expected signed integer literal for workgroup_size x parameter");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Invalid) {
- auto p = parser("workgroup_size(2, nan)");
- auto deco = p->decoration();
- EXPECT_FALSE(deco.matched);
- EXPECT_TRUE(deco.errored);
- EXPECT_EQ(deco.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:19: expected signed integer literal for workgroup_size y parameter");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Invalid) {
- auto p = parser("workgroup_size(2, 3, nan)");
- auto deco = p->decoration();
- EXPECT_FALSE(deco.matched);
- EXPECT_TRUE(deco.errored);
- EXPECT_EQ(deco.value, nullptr);
- EXPECT_TRUE(p->has_error());
- EXPECT_EQ(
- p->error(),
- "1:22: expected signed integer literal for workgroup_size z parameter");
+ EXPECT_EQ(p->error(), "1:22: expected workgroup_size z parameter");
}
TEST_F(ParserImplTest, Decoration_Stage) {
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index aac4784..15f8f7e 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -94,7 +94,8 @@
case DecorationKind::kStructBlock:
return {builder.create<ast::StructBlockDecoration>(source)};
case DecorationKind::kWorkgroup:
- return {builder.create<ast::WorkgroupDecoration>(source, 1u, 1u, 1u)};
+ return {
+ builder.create<ast::WorkgroupDecoration>(source, builder.Expr(1))};
case DecorationKind::kBindingAndGroup:
return {builder.create<ast::BindingDecoration>(source, 1u),
builder.create<ast::GroupDecoration>(source, 1u)};
@@ -664,7 +665,7 @@
TEST_F(WorkgroupDecoration, NotAnEntryPoint) {
Func("main", {}, ty.void_(), {},
- {create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
+ {create<ast::WorkgroupDecoration>(Source{{12, 34}}, Expr(1))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -675,7 +676,7 @@
TEST_F(WorkgroupDecoration, NotAComputeShader) {
Func("main", {}, ty.void_(), {},
{Stage(ast::PipelineStage::kFragment),
- create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
+ create<ast::WorkgroupDecoration>(Source{{12, 34}}, Expr(1))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
@@ -685,9 +686,8 @@
TEST_F(WorkgroupDecoration, MultipleAttributes) {
Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- create<ast::WorkgroupDecoration>(1u),
- create<ast::WorkgroupDecoration>(2u)});
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1),
+ WorkgroupSize(2)});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc
index 423270b..a542b9d 100644
--- a/src/resolver/function_validation_test.cc
+++ b/src/resolver/function_validation_test.cc
@@ -244,5 +244,127 @@
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
+ // [[stage(compute), workgroup_size(64.0)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(create<ast::ScalarConstructorExpression>(
+ Source{Source::Location{12, 34}}, Literal(64.f)))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size parameter must be a literal i32 or an "
+ "i32 module-scope constant");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
+ // [[stage(compute), workgroup_size(-2)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(create<ast::ScalarConstructorExpression>(
+ Source{Source::Location{12, 34}}, Literal(-2)))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Zero) {
+ // [[stage(compute), workgroup_size(0)]
+ // fn main() {}
+
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(create<ast::ScalarConstructorExpression>(
+ Source{Source::Location{12, 34}}, Literal(0)))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
+ // let x = 64.0;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.f32(), Expr(64.f));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size parameter must be a literal i32 or an "
+ "i32 module-scope constant");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
+ // let x = -2;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.i32(), Expr(-2));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Zero) {
+ // let x = 0;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.i32(), Expr(0));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+ WorkgroupSize_Const_NestedZeroValueConstructor) {
+ // let x = i32(i32(i32()));
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ GlobalConst("x", ty.i32(),
+ Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32()))));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
+ // var<private> x = 0;
+ // [[stage(compute), workgroup_size(x)]
+ // fn main() {}
+ Global("x", ty.i32(), ast::StorageClass::kPrivate, Expr(64));
+ Func("main", {}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size parameter must be a literal i32 or an "
+ "i32 module-scope constant");
+}
+
} // namespace
} // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 93266aa..227331b 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1295,10 +1295,79 @@
if (auto* workgroup =
ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
- // TODO(crbug.com/tint/713): Handle non-literals.
- info->workgroup_size[0].value = std::get<0>(workgroup->values());
- info->workgroup_size[1].value = std::get<1>(workgroup->values());
- info->workgroup_size[2].value = std::get<2>(workgroup->values());
+ auto values = workgroup->values();
+ for (int i = 0; i < 3; i++) {
+ // Each argument to this decoration can either be a literal, an
+ // identifier for a module-scope constants, or nullptr if not specified.
+
+ if (!values[i]) {
+ // Not specified, just use the default.
+ continue;
+ }
+
+ Mark(values[i]);
+
+ int32_t value = 0;
+ if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
+ // We have an identifier of a module-scope constant.
+ if (!Identifier(ident)) {
+ return false;
+ }
+
+ VariableInfo* var;
+ if (!variable_stack_.get(ident->symbol(), &var) ||
+ !(var->declaration->is_const() && var->type->Is<sem::I32>())) {
+ diagnostics_.add_error(
+ "workgroup_size parameter must be a literal i32 or an i32 "
+ "module-scope constant",
+ values[i]->source());
+ return false;
+ }
+
+ // Capture the constant if an [[override]] attribute is present.
+ if (ast::HasDecoration<ast::OverrideDecoration>(
+ var->declaration->decorations())) {
+ info->workgroup_size[i].overridable_const = var->declaration;
+ }
+
+ auto* constructor = var->declaration->constructor();
+ if (constructor) {
+ // Resolve the constructor expression to use as the default value.
+ if (!GetScalarConstExprValue(constructor, &value)) {
+ return false;
+ }
+ } else {
+ // No constructor means this value must be overriden by the user.
+ info->workgroup_size[i].value = 0;
+ continue;
+ }
+ } else if (auto* scalar =
+ values[i]->As<ast::ScalarConstructorExpression>()) {
+ // We have a literal.
+ Mark(scalar->literal());
+
+ if (!scalar->literal()->Is<ast::IntLiteral>()) {
+ diagnostics_.add_error(
+ "workgroup_size parameter must be a literal i32 or an i32 "
+ "module-scope constant",
+ values[i]->source());
+ return false;
+ }
+
+ if (!GetScalarConstExprValue(scalar, &value)) {
+ return false;
+ }
+ }
+
+ // Validate and set the default value for this dimension.
+ if (value < 1) {
+ diagnostics_.add_error(
+ "workgroup_size parameter must be a positive i32 value",
+ values[i]->source());
+ return false;
+ }
+ info->workgroup_size[i].value = value;
+ }
}
if (!ValidateFunction(func, info)) {
@@ -3098,6 +3167,40 @@
return true;
}
+template <typename T>
+bool Resolver::GetScalarConstExprValue(ast::Expression* expr, T* result) {
+ if (auto* type_constructor = expr->As<ast::TypeConstructorExpression>()) {
+ if (type_constructor->values().size() == 0) {
+ // Zero-valued constructor.
+ *result = static_cast<T>(0);
+ return true;
+ } else if (type_constructor->values().size() == 1) {
+ // Recurse into the constructor argument expression.
+ return GetScalarConstExprValue(type_constructor->values()[0], result);
+ } else {
+ TINT_ICE(diagnostics_) << "malformed scalar type constructor";
+ }
+ } else if (auto* scalar = expr->As<ast::ScalarConstructorExpression>()) {
+ // Cast literal to result type.
+ if (auto* int_lit = scalar->literal()->As<ast::IntLiteral>()) {
+ *result = static_cast<T>(int_lit->value_as_u32());
+ return true;
+ } else if (auto* float_lit = scalar->literal()->As<ast::FloatLiteral>()) {
+ *result = static_cast<T>(float_lit->value());
+ return true;
+ } else if (auto* bool_lit = scalar->literal()->As<ast::BoolLiteral>()) {
+ *result = static_cast<T>(bool_lit->IsTrue());
+ return true;
+ } else {
+ TINT_ICE(diagnostics_) << "unhandled scalar constructor";
+ }
+ } else {
+ TINT_ICE(diagnostics_) << "unhandled constant expression";
+ }
+
+ return false;
+}
+
template <typename F>
bool Resolver::BlockScope(const ast::BlockStatement* block, F&& callback) {
auto* sem_block = builder_->Sem().Get<sem::BlockStatement>(block);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ad428a8..8a622d2 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -323,6 +323,13 @@
typ::Type type,
const std::string& type_name);
+ /// Resolve the value of a scalar const_expr.
+ /// @param expr the expression
+ /// @param result pointer to the where the result will be stored
+ /// @returns true on success, false on error
+ template <typename T>
+ bool GetScalarConstExprValue(ast::Expression* expr, T* result);
+
/// Constructs a new semantic BlockStatement with the given type and with
/// #current_block_ as its parent, assigns this to #current_block_, and then
/// calls `callback`. The original #current_block_ is restored on exit.
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 109b31a..41e15e2 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -26,6 +26,7 @@
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h"
+#include "src/ast/override_decoration.h"
#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
@@ -889,6 +890,8 @@
}
TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
+ // [[stage(compute)]]
+ // fn main() {}
auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -905,9 +908,11 @@
}
TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
- auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
- {Stage(ast::PipelineStage::kCompute),
- create<ast::WorkgroupDecoration>(8, 2, 3)});
+ // [[stage(compute), workgroup_size(8, 2, 3)]]
+ // fn main() {}
+ auto* func =
+ Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 3)});
EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -922,6 +927,134 @@
EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
}
+TEST_F(ResolverTest, Function_WorkgroupSize_Consts) {
+ // let width = 16;
+ // let height = 8;
+ // let depth = 2;
+ // [[stage(compute), workgroup_size(width, height, depth)]]
+ // fn main() {}
+ GlobalConst("width", ty.i32(), Expr(16));
+ GlobalConst("height", ty.i32(), Expr(8));
+ GlobalConst("depth", ty.i32(), Expr(2));
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->workgroup_size()[0].value, 16u);
+ EXPECT_EQ(func_sem->workgroup_size()[1].value, 8u);
+ EXPECT_EQ(func_sem->workgroup_size()[2].value, 2u);
+ EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Consts_NestedInitializer) {
+ // let width = i32(i32(i32(8)));
+ // let height = i32(i32(i32(4)));
+ // [[stage(compute), workgroup_size(width, height)]]
+ // fn main() {}
+ GlobalConst("width", ty.i32(),
+ Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8))));
+ GlobalConst("height", ty.i32(),
+ Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 4))));
+ auto* func = Func(
+ "main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute), WorkgroupSize("width", "height")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
+ EXPECT_EQ(func_sem->workgroup_size()[1].value, 4u);
+ EXPECT_EQ(func_sem->workgroup_size()[2].value, 1u);
+ EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
+ // [[override(0)]] let width = 16;
+ // [[override(1)]] let height = 8;
+ // [[override(2)]] let depth = 2;
+ // [[stage(compute), workgroup_size(width, height, depth)]]
+ // fn main() {}
+ auto* width = GlobalConst("width", ty.i32(), Expr(16), {Override(0)});
+ auto* height = GlobalConst("height", ty.i32(), Expr(8), {Override(1)});
+ auto* depth = GlobalConst("depth", ty.i32(), Expr(2), {Override(2)});
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->workgroup_size()[0].value, 16u);
+ EXPECT_EQ(func_sem->workgroup_size()[1].value, 8u);
+ EXPECT_EQ(func_sem->workgroup_size()[2].value, 2u);
+ EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, width);
+ EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
+ EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, depth);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
+ // [[override(0)]] let width : i32;
+ // [[override(1)]] let height : i32;
+ // [[override(2)]] let depth : i32;
+ // [[stage(compute), workgroup_size(width, height, depth)]]
+ // fn main() {}
+ auto* width = GlobalConst("width", ty.i32(), nullptr, {Override(0)});
+ auto* height = GlobalConst("height", ty.i32(), nullptr, {Override(1)});
+ auto* depth = GlobalConst("depth", ty.i32(), nullptr, {Override(2)});
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize("width", "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->workgroup_size()[0].value, 0u);
+ EXPECT_EQ(func_sem->workgroup_size()[1].value, 0u);
+ EXPECT_EQ(func_sem->workgroup_size()[2].value, 0u);
+ EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, width);
+ EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
+ EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, depth);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
+ // [[override(1)]] let height = 2;
+ // let depth = 3;
+ // [[stage(compute), workgroup_size(8, height, depth)]]
+ // fn main() {}
+ auto* height = GlobalConst("height", ty.i32(), Expr(2), {Override(0)});
+ GlobalConst("depth", ty.i32(), Expr(3));
+ auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+ {Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(8, "height", "depth")});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
+ EXPECT_EQ(func_sem->workgroup_size()[1].value, 2u);
+ EXPECT_EQ(func_sem->workgroup_size()[2].value, 3u);
+ EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+ EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
+ EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* st = Structure("S", {Member("first_member", ty.i32()),
Member("second_member", ty.f32())});
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 68f55c7..502b60d 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -926,7 +926,7 @@
},
{
Stage(ast::PipelineStage::kCompute),
- create<ast::WorkgroupDecoration>(2u, 4u, 6u),
+ WorkgroupSize(2, 4, 6),
});
GeneratorImpl& gen = Build();
diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc
index c394926..91bc005 100644
--- a/src/writer/spirv/builder_function_decoration_test.cc
+++ b/src/writer/spirv/builder_function_decoration_test.cc
@@ -200,7 +200,7 @@
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) {
auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{
- create<ast::WorkgroupDecoration>(2u, 4u, 6u),
+ WorkgroupSize(2, 4, 6),
Stage(ast::PipelineStage::kCompute),
});
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 45a77b8..ef96c53 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -601,12 +601,28 @@
first = false;
if (auto* workgroup = deco->As<ast::WorkgroupDecoration>()) {
- uint32_t x = 0;
- uint32_t y = 0;
- uint32_t z = 0;
- std::tie(x, y, z) = workgroup->values();
- out_ << "workgroup_size(" << std::to_string(x) << ", "
- << std::to_string(y) << ", " << std::to_string(z) << ")";
+ auto values = workgroup->values();
+ out_ << "workgroup_size(";
+ for (int i = 0; i < 3; i++) {
+ if (values[i]) {
+ if (i > 0) {
+ out_ << ", ";
+ }
+ if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
+ if (!EmitIdentifier(ident)) {
+ return false;
+ }
+ } else if (auto* scalar =
+ values[i]->As<ast::ScalarConstructorExpression>()) {
+ if (!EmitScalarConstructor(scalar)) {
+ return false;
+ }
+ } else {
+ TINT_ICE(diagnostics_) << "Unsupported workgroup_size expression";
+ }
+ }
+ }
+ out_ << ")";
} else if (deco->Is<ast::StructBlockDecoration>()) {
out_ << "block";
} else if (auto* stage = deco->As<ast::StageDecoration>()) {
diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc
index fc69ccf..0e19f5a 100644
--- a/src/writer/wgsl/generator_impl_function_test.cc
+++ b/src/writer/wgsl/generator_impl_function_test.cc
@@ -69,13 +69,10 @@
TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) {
auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::DiscardStatement>(),
- Return(),
- },
+ ast::StatementList{Return()},
ast::DecorationList{
Stage(ast::PipelineStage::kCompute),
- create<ast::WorkgroupDecoration>(2u, 4u, 6u),
+ WorkgroupSize(2, 4, 6),
});
GeneratorImpl& gen = Build();
@@ -85,20 +82,19 @@
ASSERT_TRUE(gen.EmitFunction(func));
EXPECT_EQ(gen.result(), R"( [[stage(compute), workgroup_size(2, 4, 6)]]
fn my_func() {
- discard;
return;
}
)");
}
-TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) {
+TEST_F(WgslGeneratorImplTest,
+ Emit_Function_WithDecoration_WorkgroupSize_WithIdent) {
+ GlobalConst("height", ty.i32(), Expr(2));
auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
- ast::StatementList{
- create<ast::DiscardStatement>(),
- Return(),
- },
+ ast::StatementList{Return()},
ast::DecorationList{
- Stage(ast::PipelineStage::kFragment),
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(2, "height"),
});
GeneratorImpl& gen = Build();
@@ -106,9 +102,8 @@
gen.increment_indent();
ASSERT_TRUE(gen.EmitFunction(func));
- EXPECT_EQ(gen.result(), R"( [[stage(fragment)]]
+ EXPECT_EQ(gen.result(), R"( [[stage(compute), workgroup_size(2, height)]]
fn my_func() {
- discard;
return;
}
)");
diff --git a/test/samples/function.wgsl.expected.wgsl b/test/samples/function.wgsl.expected.wgsl
index d8a4c35..4cd4fb6 100644
--- a/test/samples/function.wgsl.expected.wgsl
+++ b/test/samples/function.wgsl.expected.wgsl
@@ -2,6 +2,6 @@
return (((2.0 * 3.0) - 4.0) / 5.0);
}
-[[stage(compute), workgroup_size(2, 1, 1)]]
+[[stage(compute), workgroup_size(2)]]
fn ep() {
}