Update `workgroup_size` to use `expression`.

This CL updates the `workgroup_size` attribute to use `expression`
values instead of `primary_expression`.

Bug: tint:1633
Change-Id: I0afbabd8ee61943469f04a55d56f85920563e2da
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99960
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index 5da80d6..c797aab 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2971,6 +2971,15 @@
     /// Creates an ast::WorkgroupAttribute
     /// @param source the source information
     /// @param x the x dimension expression
+    /// @returns the workgroup attribute pointer
+    template <typename EXPR_X>
+    const ast::WorkgroupAttribute* WorkgroupSize(const Source& source, EXPR_X&& x) {
+        return WorkgroupSize(source, std::forward<EXPR_X>(x), nullptr, nullptr);
+    }
+
+    /// Creates an ast::WorkgroupAttribute
+    /// @param source the source information
+    /// @param x the x dimension expression
     /// @param y the y dimension expression
     /// @returns the workgroup attribute pointer
     template <typename EXPR_X, typename EXPR_Y>
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index eff842b..bf95b4e 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -3405,28 +3405,25 @@
 }
 
 // attribute
-//   : ATTR 'align' PAREN_LEFT expression attrib_end
-//   | ATTR 'binding' PAREN_LEFT expression attrib_end
-//   | ATTR 'builtin' PAREN_LEFT builtin_value_name attrib_end
+//   : ATTR 'align' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'binding' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'builtin' PAREN_LEFT builtin_value_name COMMA? PAREN_RIGHT
 //   | ATTR 'const'
-//   | ATTR 'group' PAREN_LEFT expression attrib_end
-//   | ATTR 'id' PAREN_LEFT expression attrib_end
-//   | ATTR 'interpolate' PAREN_LEFT interpolation_type_name attrib_end
+//   | ATTR 'group' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'id' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA? PAREN_RIGHT
 //   | ATTR 'interpolate' PAREN_LEFT interpolation_type_name COMMA
-//                                   interpolation_sample_name attrib_end
+//                                   interpolation_sample_name COMMA? PAREN_RIGHT
 //   | ATTR 'invariant'
-//   | ATTR 'location' PAREN_LEFT expression attrib_end
-//   | ATTR 'size' PAREN_LEFT expression attrib_end
-//   | ATTR 'workgroup_size' PAREN_LEFT expression attrib_end
-//   | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression attrib_end
-//   | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA expression attrib_end
+//   | ATTR 'location' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'size' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'workgroup_size' PAREN_LEFT expression COMMA? PAREN_RIGHT
+//   | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA? PAREN_RIGHT
+//   | ATTR 'workgroup_size' PAREN_LEFT expression COMMA expression COMMA
+//                                      expression COMMA? PAREN_RIGHT
 //   | ATTR 'vertex'
 //   | ATTR 'fragment'
 //   | ATTR 'compute'
-//
-// attrib_end
-//   : COMMA? PAREN_RIGHT
-//
 Maybe<const ast::Attribute*> ParserImpl::attribute() {
     using Result = Maybe<const ast::Attribute*>;
     auto& t = next();
@@ -3603,7 +3600,7 @@
             const ast::Expression* y = nullptr;
             const ast::Expression* z = nullptr;
 
-            auto expr = primary_expression();
+            auto expr = expression();
             if (expr.errored) {
                 return Failure::kErrored;
             } else if (!expr.matched) {
@@ -3613,7 +3610,7 @@
 
             if (match(Token::Type::kComma)) {
                 if (!peek_is(Token::Type::kParenRight)) {
-                    expr = primary_expression();
+                    expr = expression();
                     if (expr.errored) {
                         return Failure::kErrored;
                     } else if (!expr.matched) {
@@ -3623,7 +3620,7 @@
 
                     if (match(Token::Type::kComma)) {
                         if (!peek_is(Token::Type::kParenRight)) {
-                            expr = primary_expression();
+                            expr = expression();
                             if (expr.errored) {
                                 return Failure::kErrored;
                             } else if (!expr.matched) {
diff --git a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
index 8e56ee7..60e8f86 100644
--- a/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_function_attribute_test.cc
@@ -41,6 +41,35 @@
     EXPECT_EQ(values[2], nullptr);
 }
 
+TEST_F(ParserImplTest, Attribute_Workgroup_Expression) {
+    auto p = parser("workgroup_size(4 + 2)");
+    auto attr = p->attribute();
+    EXPECT_TRUE(attr.matched);
+    EXPECT_FALSE(attr.errored);
+    ASSERT_NE(attr.value, nullptr) << p->error();
+    ASSERT_FALSE(p->has_error());
+    auto* func_attr = attr.value->As<ast::Attribute>();
+    ASSERT_NE(func_attr, nullptr);
+    ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+    auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+    ASSERT_TRUE(values[0]->Is<ast::BinaryExpression>());
+    auto* expr = values[0]->As<ast::BinaryExpression>();
+    EXPECT_EQ(expr->op, ast::BinaryOp::kAdd);
+
+    EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 4);
+    EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
+    EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    EXPECT_EQ(values[1], nullptr);
+    EXPECT_EQ(values[2], nullptr);
+}
+
 TEST_F(ParserImplTest, Attribute_Workgroup_1Param_TrailingComma) {
     auto p = parser("workgroup_size(4,)");
     auto attr = p->attribute();
@@ -99,6 +128,39 @@
     EXPECT_EQ(values[2], nullptr);
 }
 
+TEST_F(ParserImplTest, Attribute_Workgroup_2Param_Expression) {
+    auto p = parser("workgroup_size(4, 5 - 2)");
+    auto attr = p->attribute();
+    EXPECT_TRUE(attr.matched);
+    EXPECT_FALSE(attr.errored);
+    ASSERT_NE(attr.value, nullptr) << p->error();
+    ASSERT_FALSE(p->has_error());
+    auto* func_attr = attr.value->As<ast::Attribute>();
+    ASSERT_NE(func_attr, nullptr) << p->error();
+    ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+    auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+    ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
+    EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    ASSERT_TRUE(values[1]->Is<ast::BinaryExpression>());
+    auto* expr = values[1]->As<ast::BinaryExpression>();
+    EXPECT_EQ(expr->op, ast::BinaryOp::kSubtract);
+
+    EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 5);
+    EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 2);
+    EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    EXPECT_EQ(values[2], nullptr);
+}
+
 TEST_F(ParserImplTest, Attribute_Workgroup_2Param_TrailingComma) {
     auto p = parser("workgroup_size(4, 5,)");
     auto attr = p->attribute();
@@ -164,6 +226,42 @@
               ast::IntLiteralExpression::Suffix::kNone);
 }
 
+TEST_F(ParserImplTest, Attribute_Workgroup_3Param_Expression) {
+    auto p = parser("workgroup_size(4, 5, 6 << 1)");
+    auto attr = p->attribute();
+    EXPECT_TRUE(attr.matched);
+    EXPECT_FALSE(attr.errored);
+    ASSERT_NE(attr.value, nullptr) << p->error();
+    ASSERT_FALSE(p->has_error());
+    auto* func_attr = attr.value->As<ast::Attribute>();
+    ASSERT_NE(func_attr, nullptr);
+    ASSERT_TRUE(func_attr->Is<ast::WorkgroupAttribute>());
+
+    auto values = func_attr->As<ast::WorkgroupAttribute>()->Values();
+
+    ASSERT_TRUE(values[0]->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->value, 4);
+    EXPECT_EQ(values[0]->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    ASSERT_TRUE(values[1]->Is<ast::IntLiteralExpression>());
+    EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->value, 5);
+    EXPECT_EQ(values[1]->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    ASSERT_TRUE(values[2]->Is<ast::BinaryExpression>());
+    auto* expr = values[2]->As<ast::BinaryExpression>();
+    EXPECT_EQ(expr->op, ast::BinaryOp::kShiftLeft);
+
+    EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->value, 6);
+    EXPECT_EQ(expr->lhs->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+
+    EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->value, 1);
+    EXPECT_EQ(expr->rhs->As<ast::IntLiteralExpression>()->suffix,
+              ast::IntLiteralExpression::Suffix::kNone);
+}
+
 TEST_F(ParserImplTest, Attribute_Workgroup_3Param_TrailingComma) {
     auto p = parser("workgroup_size(4, 5, 6,)");
     auto attr = p->attribute();
diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc
index f699504..22f2620 100644
--- a/src/tint/resolver/function_validation_test.cc
+++ b/src/tint/resolver/function_validation_test.cc
@@ -545,6 +545,19 @@
     ASSERT_TRUE(r()->Resolve()) << r()->error();
 }
 
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Expr) {
+    // @compute @workgroup_size(1 + 2)
+    // fn main() {}
+
+    Func("main", utils::Empty, ty.void_(), utils::Empty,
+         utils::Vector{
+             Stage(ast::PipelineStage::kCompute),
+             WorkgroupSize(Source{{12, 34}}, Add(1_u, 2_u)),
+         });
+
+    ASSERT_TRUE(r()->Resolve()) << r()->error();
+}
+
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_MismatchType_U32) {
     // @compute @workgroup_size(1u, 2, 3_i)
     // fn main() {}
@@ -750,13 +763,43 @@
               "overridable of type abstract-integer, i32 or u32");
 }
 
-TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr) {
-    // @compute @workgroup_size(i32(1))
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
+    // @compute @workgroup_size(1 << 2 + 4)
     // fn main() {}
     Func("main", utils::Empty, ty.void_(), utils::Empty,
          utils::Vector{
              Stage(ast::PipelineStage::kCompute),
-             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), 1_i)),
+             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: workgroup_size argument must be either a literal, constant, or "
+              "overridable of type abstract-integer, i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
+    // @compute @workgroup_size(1, 1 << 2 + 4)
+    // fn main() {}
+    Func("main", utils::Empty, ty.void_(), utils::Empty,
+         utils::Vector{
+             Stage(ast::PipelineStage::kCompute),
+             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
+         });
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_EQ(r()->error(),
+              "12:34 error: workgroup_size argument must be either a literal, constant, or "
+              "overridable of type abstract-integer, i32 or u32");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
+    // @compute @workgroup_size(1, 1, 1 << 2 + 4)
+    // fn main() {}
+    Func("main", utils::Empty, ty.void_(), utils::Empty,
+         utils::Vector{
+             Stage(ast::PipelineStage::kCompute),
+             WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
          });
 
     EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index ed07399..90ec254 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -965,7 +965,7 @@
 
     for (size_t i = 0; i < 3; i++) {
         // Each argument to this attribute can either be a literal, an identifier for a module-scope
-        // constants, or nullptr if not specified.
+        // constants, a constant expression, or nullptr if not specified.
         auto* value = values[i];
         if (!value) {
             break;
@@ -1023,7 +1023,7 @@
                 ws[i].value = 0;
                 continue;
             }
-        } else if (values[i]->Is<ast::LiteralExpression>()) {
+        } else if (values[i]->Is<ast::LiteralExpression>() || args[i]->ConstantValue()) {
             value = materialized->ConstantValue();
         } else {
             AddError(kErrBadExpr, values[i]->source);
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl b/test/tint/shader_io/compute_workgroup_expression.wgsl
new file mode 100644
index 0000000..a465a76
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl
@@ -0,0 +1,7 @@
+@id(0) override x_dim = 2;
+
+@compute
+@workgroup_size(1 + 2, x_dim, clamp((1 - 2) + 4, 0, 5))
+fn main() {
+}
+
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl
new file mode 100644
index 0000000..2098cb3
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.dxc.hlsl
@@ -0,0 +1,9 @@
+#ifndef WGSL_SPEC_CONSTANT_0
+#define WGSL_SPEC_CONSTANT_0 2
+#endif
+static const int x_dim = WGSL_SPEC_CONSTANT_0;
+
+[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
+void main() {
+  return;
+}
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl
new file mode 100644
index 0000000..2098cb3
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.fxc.hlsl
@@ -0,0 +1,9 @@
+#ifndef WGSL_SPEC_CONSTANT_0
+#define WGSL_SPEC_CONSTANT_0 2
+#endif
+static const int x_dim = WGSL_SPEC_CONSTANT_0;
+
+[numthreads(3, WGSL_SPEC_CONSTANT_0, 3)]
+void main() {
+  return;
+}
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl
new file mode 100644
index 0000000..6ac96fa
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.glsl
@@ -0,0 +1,14 @@
+#version 310 es
+
+#ifndef WGSL_SPEC_CONSTANT_0
+#define WGSL_SPEC_CONSTANT_0 2
+#endif
+const int x_dim = WGSL_SPEC_CONSTANT_0;
+void tint_symbol() {
+}
+
+layout(local_size_x = 3, local_size_y = WGSL_SPEC_CONSTANT_0, local_size_z = 3) in;
+void main() {
+  tint_symbol();
+  return;
+}
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl
new file mode 100644
index 0000000..d9f2428
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.msl
@@ -0,0 +1,9 @@
+#include <metal_stdlib>
+
+using namespace metal;
+constant int x_dim [[function_constant(0)]];
+
+kernel void tint_symbol() {
+  return;
+}
+
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm
new file mode 100644
index 0000000..bcd4536
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.spvasm
@@ -0,0 +1,26 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 12
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpName %x_dim "x_dim"
+               OpName %main "main"
+               OpDecorate %x_dim SpecId 0
+               OpDecorate %11 SpecId 0
+               OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
+        %int = OpTypeInt 32 1
+      %x_dim = OpSpecConstant %int 2
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+     %v3uint = OpTypeVector %uint 3
+     %uint_3 = OpConstant %uint 3
+         %11 = OpSpecConstant %uint 2
+%gl_WorkGroupSize = OpSpecConstantComposite %v3uint %uint_3 %11 %uint_3
+       %main = OpFunction %void None %3
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd
diff --git a/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl
new file mode 100644
index 0000000..c2b260b
--- /dev/null
+++ b/test/tint/shader_io/compute_workgroup_expression.wgsl.expected.wgsl
@@ -0,0 +1,5 @@
+@id(0) override x_dim = 2;
+
+@compute @workgroup_size((1 + 2), x_dim, clamp(((1 - 2) + 4), 0, 5))
+fn main() {
+}