Update `workgroup_size` to use `expression`.
This CL updates the `workgroup_size` attribute to use `expression`
values instead of `primary_expression`.
Bug: tint:1633
Change-Id: I0afbabd8ee61943469f04a55d56f85920563e2da
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99960
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 5da80d6..c797aab 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2971,6 +2971,15 @@
/// Creates an ast::WorkgroupAttribute
/// @param source the source information
/// @param x the x dimension expression
+ /// @returns the workgroup attribute pointer
+ template <typename EXPR_X>
+ const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x) {
+ return WorkgroupSize(source, std::forward<EXPR_X>(x), nullptr, nullptr);
+ }
+
+ /// Creates an ast::WorkgroupAttribute
+ /// @param source the source information
+ /// @param x the x dimension expression
/// @param y the y dimension expression
/// @returns the workgroup attribute pointer
template <typename EXPR_X, typename EXPR_Y>
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index eff842b..bf95b4e 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -3405,28 +3405,25 @@
}
// attribute
-// : ATTR 'align' PAREN_LEFT expression attrib_end
-// | ATTR 'binding' PAREN_LEFT expression attrib_end
-// | ATTR 'builtin' PAREN_LEFT builtin_value_name attrib_end
+// : ATTR 'align' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'binding' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'builtin' PAREN_LEFT builtin_value_name COMMA? PAREN_RIGHT
// | ATTR 'const'
-// | ATTR 'group' PAREN_LEFT expression attrib_end
-// | ATTR 'id' PAREN_LEFT expression attrib_end
-// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name attrib_end
+// | ATTR 'group' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'id' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA? PAREN_RIGHT
// | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA
-// interpolation_sample_name attrib_end
+// interpolation_sample_name COMMA? PAREN_RIGHT
// | ATTR 'invariant'
-// | ATTR 'location' PAREN_LEFT expression attrib_end
-// | ATTR 'size' PAREN_LEFT expression attrib_end
-// | ATTR 'workgroup_size' PAREN_LEFT expression attrib_end
-// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression attrib_end
-// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA expression attrib_end
+// | ATTR 'location' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'size' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA? PAREN_RIGHT
+// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA? PAREN_RIGHT
+// | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA
+// expression COMMA? PAREN_RIGHT
// | ATTR 'vertex'
// | ATTR 'fragment'
// | ATTR 'compute'
-//
-// attrib_end
-// : COMMA? PAREN_RIGHT
-//
Maybe<const ast::Attribute*> ParserImpl::attribute() {
using Result = Maybe<const ast::Attribute*>;
auto& t = next();
@@ -3603,7 +3600,7 @@
const ast::Expression* y = nullptr;
const ast::Expression* z = nullptr;
- auto expr = primary_expression();
+ auto expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
@@ -3613,7 +3610,7 @@
if (match(Token::Type::kComma)) {
if (!peek_is(Token::Type::kParenRight)) {
- expr = primary_expression();
+ expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
@@ -3623,7 +3620,7 @@
if (match(Token::Type::kComma)) {
if (!peek_is(Token::Type::kParenRight)) {
- expr = primary_expression();
+ expr = expression();
if (expr.errored) {
return Failure::kErrored;
} else if (!expr.matched) {
diff --git a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
index 8e56ee7..60e8f86 100644
--- a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
@@ -41,6 +41,35 @@
EXPECT_EQ(values[2], nullptr);
}
+TEST_F(ParserImplTest, Attribute_Workgroup_Expression) {
+ auto p = parser("workgroup_size(4 + 2)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(func_attr, nullptr);
+ ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+ auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+ ASSERT_TRUE(values[0]->Is<ast::BinaryExpression>());
+ auto* expr = values[0]->As<ast::BinaryExpression>();
+ EXPECT_EQ(expr->op, ast::BinaryOp::kAdd);
+
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(values[1], nullptr);
+ EXPECT_EQ(values[2], nullptr);
+}
+
TEST_F(ParserImplTest, Attribute_Workgroup_1Param_TrailingComma) {
auto p = parser("workgroup_size(4,)");
auto attr = p->attribute();
@@ -99,6 +128,39 @@
EXPECT_EQ(values[2], nullptr);
}
+TEST_F(ParserImplTest, Attribute_Workgroup_2Param_Expression) {
+ auto p = parser("workgroup_size(4, 5 - 2)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(func_attr, nullptr) << p->error();
+ ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+ auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+ ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(values[1]->Is<ast::BinaryExpression>());
+ auto* expr = values[1]->As<ast::BinaryExpression>();
+ EXPECT_EQ(expr->op, ast::BinaryOp::kSubtract);
+
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 5);
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(values[2], nullptr);
+}
+
TEST_F(ParserImplTest, Attribute_Workgroup_2Param_TrailingComma) {
auto p = parser("workgroup_size(4, 5,)");
auto attr = p->attribute();
@@ -164,6 +226,42 @@
ast::IntLiteralExpression::Suffix::kNone);
}
+TEST_F(ParserImplTest, Attribute_Workgroup_3Param_Expression) {
+ auto p = parser("workgroup_size(4, 5, 6 << 1)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr) << p->error();
+ ASSERT_FALSE(p->has_error());
+ auto* func_attr = attr.value->As<ast::Attribute>();
+ ASSERT_NE(func_attr, nullptr);
+ ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+ auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+ ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
+ EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(values[1]->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->value, 5);
+ EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ ASSERT_TRUE(values[2]->Is<ast::BinaryExpression>());
+ auto* expr = values[2]->As<ast::BinaryExpression>();
+ EXPECT_EQ(expr->op, ast::BinaryOp::kShiftLeft);
+
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 6);
+ EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 1);
+ EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+ ast::IntLiteralExpression::Suffix::kNone);
+}
+
TEST_F(ParserImplTest, Attribute_Workgroup_3Param_TrailingComma) {
auto p = parser("workgroup_size(4, 5, 6,)");
auto attr = p->attribute();
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index f699504..22f2620 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -545,6 +545,19 @@
ASSERT_TRUE(r()->Resolve()) << r()->error();
}
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Expr) {
+ // @compute @workgroup_size(1 + 2)
+ // fn main() {}
+
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Source{{12, 34}}, Add(1_u, 2_u)),
+ });
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
// @compute @workgroup_size(1u, 2, 3_i)
// fn main() {}
@@ -750,13 +763,43 @@
"overridable of type abstract-integer, i32 or u32");
}
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
- // @compute @workgroup_size(i32(1))
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
+ // @compute @workgroup_size(1 << 2 + 4)
// fn main() {}
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
- WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1_i)),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either a literal, constant, or "
+ "overridable of type abstract-integer, i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
+ // @compute @workgroup_size(1, 1 << 2 + 4)
+ // fn main() {}
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+ });
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: workgroup_size argument must be either a literal, constant, or "
+ "overridable of type abstract-integer, i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
+ // @compute @workgroup_size(1, 1, 1 << 2 + 4)
+ // fn main() {}
+ Func("main", utils::Empty, ty.void_(), utils::Empty,
+ utils::Vector{
+ Stage(ast::PipelineStage::kCompute),
+ WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index ed07399..90ec254 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -965,7 +965,7 @@
for (size_t i = 0; i < 3; i++) {
// Each argument to this attribute can either be a literal, an identifier for a module-scope
- // constants, or nullptr if not specified.
+ // constants, a constant expression, or nullptr if not specified.
auto* value = values[i];
if (!value) {
break;
@@ -1023,7 +1023,7 @@
ws[i].value = 0;
continue;
}
- } else if (values[i]->Is<ast::LiteralExpression>()) {
+ } else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
value = materialized->ConstantValue();
} else {
AddError(kErrBadExpr, values[i]->source);
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl b/test/tint/shader_io/compute_workgroup_expression.wgsl
new file mode 100644
index 0000000..a465a76
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl
@@ -0,0 +1,7 @@
+@id(0) override x_dim = 2;
+
+@compute
+@workgroup_size(1 + 2, x_dim, clamp((1 - 2) + 4, 0, 5))
+fn main() {
+}
+
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..2098cb3
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl
@@ -0,0 +1,9 @@
+#ifndef WGSL_SPEC_CONSTANT_0
+#define WGSL_SPEC_CONSTANT_0 2
+#endif
+static const int x_dim = WGSL_SPEC_CONSTANT_0;
+
+[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
+void main() {
+ return;
+}
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..2098cb3
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl
@@ -0,0 +1,9 @@
+#ifndef WGSL_SPEC_CONSTANT_0
+#define WGSL_SPEC_CONSTANT_0 2
+#endif
+static const int x_dim = WGSL_SPEC_CONSTANT_0;
+
+[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
+void main() {
+ return;
+}
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl
new file mode 100644
index 0000000..6ac96fa
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl
@@ -0,0 +1,14 @@
+#version 310 es
+
+#ifndef WGSL_SPEC_CONSTANT_0
+#define WGSL_SPEC_CONSTANT_0 2
+#endif
+const int x_dim = WGSL_SPEC_CONSTANT_0;
+void tint_symbol() {
+}
+
+layout(local_size_x = 3, local_size_y = WGSL_SPEC_CONSTANT_0, local_size_z = 3) in;
+void main() {
+ tint_symbol();
+ return;
+}
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl
new file mode 100644
index 0000000..d9f2428
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl
@@ -0,0 +1,9 @@
+#include <metal_stdlib>
+
+using namespace metal;
+constant int x_dim [[function_constant(0)]];
+
+kernel void tint_symbol() {
+ return;
+}
+
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm
new file mode 100644
index 0000000..bcd4536
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm
@@ -0,0 +1,26 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 12
+; Schema: 0
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main"
+ OpName %x_dim "x_dim"
+ OpName %main "main"
+ OpDecorate %x_dim SpecId 0
+ OpDecorate %11 SpecId 0
+ OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
+ %int = OpTypeInt 32 1
+ %x_dim = OpSpecConstant %int 2
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %v3uint = OpTypeVector %uint 3
+ %uint_3 = OpConstant %uint 3
+ %11 = OpSpecConstant %uint 2
+%gl_WorkGroupSize = OpSpecConstantComposite %v3uint %uint_3 %11 %uint_3
+ %main = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl
new file mode 100644
index 0000000..c2b260b
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl
@@ -0,0 +1,5 @@
+@id(0) override x_dim = 2;
+
+@compute @workgroup_size((1 + 2), x_dim, clamp(((1 - 2) + 4), 0, 5))
+fn main() {
+}