ast: Support non-literal workgroup_size parameters

Change the type of the values in an ast::WorkgroupDecoration to be
ast::Expression nodes, so that they can represent both
ast::ScalarExpression (literal) and ast::IdentifierExpression
(module-scope constant).

The Resolver processes these nodes to produce a uint32_t for the
default value on each dimension, and captures a reference to the
module-scope constant if it is overridable (which will soon be used by
the inspector and backends).

The WGSL parser now uses `primary_expression` to parse arguments to
workgroup_size.

Also added some WorkgroupSize() helpers to ProgramBuilder.

Bug: tint:713
Change-Id: I44b7b0021b925c84f25f65e26dc7da6b19ede508
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51262
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 44ba686..a634e6f 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -107,7 +107,7 @@
         ProgramBuilder b2;
         b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
                 DecorationList{
-                    b2.create<WorkgroupDecoration>(2, 4, 6),
+                    b2.WorkgroupSize(2, 4, 6),
                 });
       },
       "internal compiler error");
@@ -121,7 +121,7 @@
         b1.Func("func", VariableList{}, b1.ty.void_(), StatementList{},
                 DecorationList{},
                 DecorationList{
-                    b2.create<WorkgroupDecoration>(2, 4, 6),
+                    b2.WorkgroupSize(2, 4, 6),
                 });
       },
       "internal compiler error");
@@ -159,10 +159,14 @@
                  StatementList{
                      create<DiscardStatement>(),
                  },
-                 DecorationList{create<WorkgroupDecoration>(2, 4, 6)});
+                 DecorationList{WorkgroupSize(2, 4, 6)});
 
   EXPECT_EQ(str(f), R"(Function func -> __void
-WorkgroupDecoration{2 4 6}
+WorkgroupDecoration{
+  ScalarConstructor[not set]{2}
+  ScalarConstructor[not set]{4}
+  ScalarConstructor[not set]{6}
+}
 ()
 {
   Discard{}
diff --git a/src/ast/int_literal.h b/src/ast/int_literal.h
index 1b5c7ef..4a12ef6 100644
--- a/src/ast/int_literal.h
+++ b/src/ast/int_literal.h
@@ -25,6 +25,9 @@
  public:
   ~IntLiteral() override;
 
+  /// @returns the literal value as an i32
+  int32_t value_as_i32() const { return static_cast<int32_t>(value_); }
+
   /// @returns the literal value as a u32
   uint32_t value_as_u32() const { return value_; }
 
diff --git a/src/ast/workgroup_decoration.cc b/src/ast/workgroup_decoration.cc
index ba7d3ff..36a0d40 100644
--- a/src/ast/workgroup_decoration.cc
+++ b/src/ast/workgroup_decoration.cc
@@ -23,36 +23,36 @@
 
 WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
                                          const Source& source,
-                                         uint32_t x)
-    : WorkgroupDecoration(program_id, source, x, 1, 1) {}
-
-WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
-                                         const Source& source,
-                                         uint32_t x,
-                                         uint32_t y)
-    : WorkgroupDecoration(program_id, source, x, y, 1) {}
-
-WorkgroupDecoration::WorkgroupDecoration(ProgramID program_id,
-                                         const Source& source,
-                                         uint32_t x,
-                                         uint32_t y,
-                                         uint32_t z)
+                                         ast::Expression* x,
+                                         ast::Expression* y,
+                                         ast::Expression* z)
     : Base(program_id, source), x_(x), y_(y), z_(z) {}
 
 WorkgroupDecoration::~WorkgroupDecoration() = default;
 
-void WorkgroupDecoration::to_str(const sem::Info&,
+void WorkgroupDecoration::to_str(const sem::Info& sem,
                                  std::ostream& out,
                                  size_t indent) const {
   make_indent(out, indent);
-  out << "WorkgroupDecoration{" << x_ << " " << y_ << " " << z_ << "}"
-      << std::endl;
+  out << "WorkgroupDecoration{" << std::endl;
+  x_->to_str(sem, out, indent + 2);
+  if (y_) {
+    y_->to_str(sem, out, indent + 2);
+    if (z_) {
+      z_->to_str(sem, out, indent + 2);
+    }
+  }
+  make_indent(out, indent);
+  out << "}" << std::endl;
 }
 
 WorkgroupDecoration* WorkgroupDecoration::Clone(CloneContext* ctx) const {
   // Clone arguments outside of create() call to have deterministic ordering
   auto src = ctx->Clone(source());
-  return ctx->dst->create<WorkgroupDecoration>(src, x_, y_, z_);
+  auto* x = ctx->Clone(x_);
+  auto* y = ctx->Clone(y_);
+  auto* z = ctx->Clone(z_);
+  return ctx->dst->create<WorkgroupDecoration>(src, x, y, z);
 }
 
 }  // namespace ast
diff --git a/src/ast/workgroup_decoration.h b/src/ast/workgroup_decoration.h
index 5838168..fa68e87 100644
--- a/src/ast/workgroup_decoration.h
+++ b/src/ast/workgroup_decoration.h
@@ -15,47 +15,35 @@
 #ifndef SRC_AST_WORKGROUP_DECORATION_H_
 #define SRC_AST_WORKGROUP_DECORATION_H_
 
-#include <tuple>
+#include <array>
 
 #include "src/ast/decoration.h"
 
 namespace tint {
 namespace ast {
 
+// Forward declaration
+class Expression;
+
 /// A workgroup decoration
 class WorkgroupDecoration : public Castable<WorkgroupDecoration, Decoration> {
  public:
   /// constructor
   /// @param program_id the identifier of the program that owns this node
   /// @param source the source of this decoration
-  /// @param x the workgroup x dimension size
-  WorkgroupDecoration(ProgramID program_id, const Source& source, uint32_t x);
-  /// constructor
-  /// @param program_id the identifier of the program that owns this node
-  /// @param source the source of this decoration
-  /// @param x the workgroup x dimension size
-  /// @param y the workgroup x dimension size
+  /// @param x the workgroup x dimension expression
+  /// @param y the optional workgroup y dimension expression
+  /// @param z the optional workgroup z dimension expression
   WorkgroupDecoration(ProgramID program_id,
                       const Source& source,
-                      uint32_t x,
-                      uint32_t y);
-  /// constructor
-  /// @param program_id the identifier of the program that owns this node
-  /// @param source the source of this decoration
-  /// @param x the workgroup x dimension size
-  /// @param y the workgroup x dimension size
-  /// @param z the workgroup x dimension size
-  WorkgroupDecoration(ProgramID program_id,
-                      const Source& source,
-                      uint32_t x,
-                      uint32_t y,
-                      uint32_t z);
+                      ast::Expression* x,
+                      ast::Expression* y = nullptr,
+                      ast::Expression* z = nullptr);
+
   ~WorkgroupDecoration() override;
 
   /// @returns the workgroup dimensions
-  std::tuple<uint32_t, uint32_t, uint32_t> values() const {
-    return {x_, y_, z_};
-  }
+  std::array<ast::Expression*, 3> values() const { return {x_, y_, z_}; }
 
   /// Outputs the decoration to the given stream
   /// @param sem the semantic info for the program
@@ -72,9 +60,9 @@
   WorkgroupDecoration* Clone(CloneContext* ctx) const override;
 
  private:
-  uint32_t const x_;
-  uint32_t const y_;
-  uint32_t const z_;
+  ast::Expression* x_ = nullptr;
+  ast::Expression* y_ = nullptr;
+  ast::Expression* z_ = nullptr;
 };
 
 }  // namespace ast
diff --git a/src/ast/workgroup_decoration_test.cc b/src/ast/workgroup_decoration_test.cc
index c5c697f..1f0756d 100644
--- a/src/ast/workgroup_decoration_test.cc
+++ b/src/ast/workgroup_decoration_test.cc
@@ -24,40 +24,84 @@
 using WorkgroupDecorationTest = TestHelper;
 
 TEST_F(WorkgroupDecorationTest, Creation_1param) {
-  auto* d = create<WorkgroupDecoration>(2);
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = d->values();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 1u);
-  EXPECT_EQ(z, 1u);
+  auto* d = WorkgroupSize(2);
+  auto values = d->values();
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(x_scalar);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  EXPECT_EQ(values[1], nullptr);
+  EXPECT_EQ(values[2], nullptr);
 }
 TEST_F(WorkgroupDecorationTest, Creation_2param) {
-  auto* d = create<WorkgroupDecoration>(2, 4);
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = d->values();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 4u);
-  EXPECT_EQ(z, 1u);
+  auto* d = WorkgroupSize(2, 4);
+  auto values = d->values();
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(x_scalar);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(y_scalar);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  EXPECT_EQ(values[2], nullptr);
 }
 
 TEST_F(WorkgroupDecorationTest, Creation_3param) {
-  auto* d = create<WorkgroupDecoration>(2, 4, 6);
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = d->values();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 4u);
-  EXPECT_EQ(z, 6u);
+  auto* d = WorkgroupSize(2, 4, 6);
+  auto values = d->values();
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(x_scalar);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(y_scalar);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_NE(values[2], nullptr);
+  auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(z_scalar);
+  ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 6u);
+}
+
+TEST_F(WorkgroupDecorationTest, Creation_WithIdentifier) {
+  auto* d = WorkgroupSize(2, 4, "depth");
+  auto values = d->values();
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(x_scalar);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_TRUE(y_scalar);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_NE(values[2], nullptr);
+  auto* z_ident = values[2]->As<ast::IdentifierExpression>();
+  ASSERT_TRUE(z_ident);
+  EXPECT_EQ(Symbols().NameFor(z_ident->symbol()), "depth");
 }
 
 TEST_F(WorkgroupDecorationTest, ToStr) {
-  auto* d = create<WorkgroupDecoration>(2, 4, 6);
-  EXPECT_EQ(str(d), R"(WorkgroupDecoration{2 4 6}
+  auto* d = WorkgroupSize(2, "height");
+  EXPECT_EQ(str(d), R"(WorkgroupDecoration{
+  ScalarConstructor[not set]{2}
+  Identifier[not set]{height}
+}
 )");
 }
 
diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc
index 2c5f01f..62d95a7 100644
--- a/src/inspector/inspector_test.cc
+++ b/src/inspector/inspector_test.cc
@@ -828,10 +828,8 @@
 }
 
 TEST_F(InspectorGetEntryPointTest, NonDefaultWorkgroupSize) {
-  MakeEmptyBodyFunction("foo", ast::DecorationList{
-                                   Stage(ast::PipelineStage::kCompute),
-                                   create<ast::WorkgroupDecoration>(8u, 2u, 1u),
-                               });
+  MakeEmptyBodyFunction(
+      "foo", {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 1)});
 
   Inspector& inspector = Build();
 
diff --git a/src/program_builder.h b/src/program_builder.h
index 610f1eb..05acee4 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -59,6 +59,7 @@
 #include "src/ast/variable_decl_statement.h"
 #include "src/ast/vector.h"
 #include "src/ast/void.h"
+#include "src/ast/workgroup_decoration.h"
 #include "src/program.h"
 #include "src/program_id.h"
 #include "src/sem/array.h"
@@ -1914,6 +1915,36 @@
     return create<ast::StageDecoration>(source_, stage);
   }
 
+  /// Creates an ast::WorkgroupDecoration
+  /// @param x the x dimension expression
+  /// @returns the workgroup decoration pointer
+  template <typename EXPR_X>
+  ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x) {
+    return WorkgroupSize(std::forward<EXPR_X>(x), nullptr, nullptr);
+  }
+
+  /// Creates an ast::WorkgroupDecoration
+  /// @param x the x dimension expression
+  /// @param y the y dimension expression
+  /// @returns the workgroup decoration pointer
+  template <typename EXPR_X, typename EXPR_Y>
+  ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y) {
+    return WorkgroupSize(std::forward<EXPR_X>(x), std::forward<EXPR_Y>(y),
+                         nullptr);
+  }
+
+  /// Creates an ast::WorkgroupDecoration
+  /// @param x the x dimension expression
+  /// @param y the y dimension expression
+  /// @param z the z dimension expression
+  /// @returns the workgroup decoration pointer
+  template <typename EXPR_X, typename EXPR_Y, typename EXPR_Z>
+  ast::WorkgroupDecoration* WorkgroupSize(EXPR_X&& x, EXPR_Y&& y, EXPR_Z&& z) {
+    return create<ast::WorkgroupDecoration>(
+        source_, Expr(std::forward<EXPR_X>(x)), Expr(std::forward<EXPR_Y>(y)),
+        Expr(std::forward<EXPR_Z>(z)));
+  }
+
   /// Sets the current builder source to `src`
   /// @param src the Source used for future create() calls
   void SetSource(const Source& src) {
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index 9398034..8c91da6 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -3010,26 +3010,35 @@
 
   if (s == kWorkgroupSizeDecoration) {
     return expect_paren_block("workgroup_size decoration", [&]() -> Result {
-      uint32_t x;
-      uint32_t y = 1;
-      uint32_t z = 1;
+      ast::Expression* x = nullptr;
+      ast::Expression* y = nullptr;
+      ast::Expression* z = nullptr;
 
-      auto val = expect_nonzero_positive_sint("workgroup_size x parameter");
-      if (val.errored)
+      auto expr = primary_expression();
+      if (expr.errored) {
         return Failure::kErrored;
-      x = val.value;
+      } else if (!expr.matched) {
+        return add_error(peek(), "expected workgroup_size x parameter");
+      }
+      x = std::move(expr.value);
 
       if (match(Token::Type::kComma)) {
-        val = expect_nonzero_positive_sint("workgroup_size y parameter");
-        if (val.errored)
+        expr = primary_expression();
+        if (expr.errored) {
           return Failure::kErrored;
-        y = val.value;
+        } else if (!expr.matched) {
+          return add_error(peek(), "expected workgroup_size y parameter");
+        }
+        y = std::move(expr.value);
 
         if (match(Token::Type::kComma)) {
-          val = expect_nonzero_positive_sint("workgroup_size z parameter");
-          if (val.errored)
+          expr = primary_expression();
+          if (expr.errored) {
             return Failure::kErrored;
-          z = val.value;
+          } else if (!expr.matched) {
+            return add_error(peek(), "expected workgroup_size z parameter");
+          }
+          z = std::move(expr.value);
         }
       }
 
diff --git a/src/reader/wgsl/parser_impl_error_msg_test.cc b/src/reader/wgsl/parser_impl_error_msg_test.cc
index c9c7c83..ba54f60 100644
--- a/src/reader/wgsl/parser_impl_error_msg_test.cc
+++ b/src/reader/wgsl/parser_impl_error_msg_test.cc
@@ -325,74 +325,23 @@
 }
 
 TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXInvalid) {
-  EXPECT("[[workgroup_size(x)]] fn f() {}",
-         "test.wgsl:1:18 error: expected signed integer literal for "
-         "workgroup_size x parameter\n"
-         "[[workgroup_size(x)]] fn f() {}\n"
-         "                 ^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXNegative) {
-  EXPECT("[[workgroup_size(-1)]] fn f() {}",
-         "test.wgsl:1:18 error: workgroup_size x parameter must be greater "
-         "than 0\n"
-         "[[workgroup_size(-1)]] fn f() {}\n"
-         "                 ^^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeXZero) {
-  EXPECT("[[workgroup_size(0)]] fn f() {}",
-         "test.wgsl:1:18 error: workgroup_size x parameter must be greater "
-         "than 0\n"
-         "[[workgroup_size(0)]] fn f() {}\n"
+  EXPECT("[[workgroup_size(@)]] fn f() {}",
+         "test.wgsl:1:18 error: expected workgroup_size x parameter\n"
+         "[[workgroup_size(@)]] fn f() {}\n"
          "                 ^\n");
 }
 
 TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYInvalid) {
-  EXPECT("[[workgroup_size(1, x)]] fn f() {}",
-         "test.wgsl:1:21 error: expected signed integer literal for "
-         "workgroup_size y parameter\n"
-         "[[workgroup_size(1, x)]] fn f() {}\n"
-         "                    ^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYNegative) {
-  EXPECT("[[workgroup_size(1, -1)]] fn f() {}",
-         "test.wgsl:1:21 error: workgroup_size y parameter must be greater "
-         "than 0\n"
-         "[[workgroup_size(1, -1)]] fn f() {}\n"
-         "                    ^^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeYZero) {
-  EXPECT("[[workgroup_size(1, 0)]] fn f() {}",
-         "test.wgsl:1:21 error: workgroup_size y parameter must be greater "
-         "than 0\n"
-         "[[workgroup_size(1, 0)]] fn f() {}\n"
+  EXPECT("[[workgroup_size(1, @)]] fn f() {}",
+         "test.wgsl:1:21 error: expected workgroup_size y parameter\n"
+         "[[workgroup_size(1, @)]] fn f() {}\n"
          "                    ^\n");
 }
 
 TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZInvalid) {
-  EXPECT("[[workgroup_size(1, 2, x)]] fn f() {}",
-         "test.wgsl:1:24 error: expected signed integer literal for "
-         "workgroup_size z parameter\n"
-         "[[workgroup_size(1, 2, x)]] fn f() {}\n"
-         "                       ^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZNegative) {
-  EXPECT("[[workgroup_size(1, 2, -1)]] fn f() {}",
-         "test.wgsl:1:24 error: workgroup_size z parameter must be greater "
-         "than 0\n"
-         "[[workgroup_size(1, 2, -1)]] fn f() {}\n"
-         "                       ^^\n");
-}
-
-TEST_F(ParserImplErrorTest, FunctionDeclDecoWorkgroupSizeZZero) {
-  EXPECT("[[workgroup_size(1, 2, 0)]] fn f() {}",
-         "test.wgsl:1:24 error: workgroup_size z parameter must be greater "
-         "than 0\n"
-         "[[workgroup_size(1, 2, 0)]] fn f() {}\n"
+  EXPECT("[[workgroup_size(1, 2, @)]] fn f() {}",
+         "test.wgsl:1:24 error: expected workgroup_size z parameter\n"
+         "[[workgroup_size(1, 2, @)]] fn f() {}\n"
          "                       ^\n");
 }
 
diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc
index 3144df3..f08db28 100644
--- a/src/reader/wgsl/parser_impl_function_decl_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decl_test.cc
@@ -69,13 +69,25 @@
   ASSERT_EQ(decorations.size(), 1u);
   ASSERT_TRUE(decorations[0]->Is<ast::WorkgroupDecoration>());
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = decorations[0]->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 3u);
-  EXPECT_EQ(z, 4u);
+  auto values = decorations[0]->As<ast::WorkgroupDecoration>()->values();
+
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(y_scalar, nullptr);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
+
+  ASSERT_NE(values[2], nullptr);
+  auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(z_scalar, nullptr);
+  ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
@@ -84,7 +96,7 @@
 
 TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleEntries) {
   auto p = parser(R"(
-[[workgroup_size(2, 3, 4), workgroup_size(5, 6, 7)]]
+[[workgroup_size(2, 3, 4), stage(compute)]]
 fn main() { return; })");
   auto decos = p->decoration_list();
   EXPECT_FALSE(p->has_error()) << p->error();
@@ -104,20 +116,30 @@
   auto& decorations = f->decorations();
   ASSERT_EQ(decorations.size(), 2u);
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
   ASSERT_TRUE(decorations[0]->Is<ast::WorkgroupDecoration>());
-  std::tie(x, y, z) = decorations[0]->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 3u);
-  EXPECT_EQ(z, 4u);
+  auto values = decorations[0]->As<ast::WorkgroupDecoration>()->values();
 
-  ASSERT_TRUE(decorations[1]->Is<ast::WorkgroupDecoration>());
-  std::tie(x, y, z) = decorations[1]->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 5u);
-  EXPECT_EQ(y, 6u);
-  EXPECT_EQ(z, 7u);
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(y_scalar, nullptr);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
+
+  ASSERT_NE(values[2], nullptr);
+  auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(z_scalar, nullptr);
+  ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_TRUE(decorations[1]->Is<ast::StageDecoration>());
+  EXPECT_EQ(decorations[1]->As<ast::StageDecoration>()->value(),
+            ast::PipelineStage::kCompute);
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
@@ -127,7 +149,7 @@
 TEST_F(ParserImplTest, FunctionDecl_DecorationList_MultipleLists) {
   auto p = parser(R"(
 [[workgroup_size(2, 3, 4)]]
-[[workgroup_size(5, 6, 7)]]
+[[stage(compute)]]
 fn main() { return; })");
   auto decorations = p->decoration_list();
   EXPECT_FALSE(p->has_error()) << p->error();
@@ -147,20 +169,30 @@
   auto& decos = f->decorations();
   ASSERT_EQ(decos.size(), 2u);
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
   ASSERT_TRUE(decos[0]->Is<ast::WorkgroupDecoration>());
-  std::tie(x, y, z) = decos[0]->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 3u);
-  EXPECT_EQ(z, 4u);
+  auto values = decos[0]->As<ast::WorkgroupDecoration>()->values();
 
-  ASSERT_TRUE(decos[1]->Is<ast::WorkgroupDecoration>());
-  std::tie(x, y, z) = decos[1]->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 5u);
-  EXPECT_EQ(y, 6u);
-  EXPECT_EQ(z, 7u);
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(y_scalar, nullptr);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 3u);
+
+  ASSERT_NE(values[2], nullptr);
+  auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(z_scalar, nullptr);
+  ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_TRUE(decos[1]->Is<ast::StageDecoration>());
+  EXPECT_EQ(decos[1]->As<ast::StageDecoration>()->value(),
+            ast::PipelineStage::kCompute);
 
   auto* body = f->body();
   ASSERT_EQ(body->size(), 1u);
diff --git a/src/reader/wgsl/parser_impl_function_decoration_list_test.cc b/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
index c2272c4..0b76ef2 100644
--- a/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decoration_list_test.cc
@@ -21,7 +21,7 @@
 namespace {
 
 TEST_F(ParserImplTest, DecorationList_Parses) {
-  auto p = parser("[[workgroup_size(2), workgroup_size(3, 4, 5)]]");
+  auto p = parser("[[workgroup_size(2), stage(compute)]]");
   auto decos = p->decoration_list();
   EXPECT_FALSE(p->has_error()) << p->error();
   EXPECT_FALSE(decos.errored);
@@ -33,18 +33,17 @@
   ASSERT_NE(deco_0, nullptr);
   ASSERT_NE(deco_1, nullptr);
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
   ASSERT_TRUE(deco_0->Is<ast::WorkgroupDecoration>());
-  std::tie(x, y, z) = deco_0->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 2u);
+  ast::Expression* x = deco_0->As<ast::WorkgroupDecoration>()->values()[0];
+  ASSERT_NE(x, nullptr);
+  auto* x_scalar = x->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 2u);
 
-  ASSERT_TRUE(deco_1->Is<ast::WorkgroupDecoration>());
-  std::tie(x, y, z) = deco_1->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 3u);
-  EXPECT_EQ(y, 4u);
-  EXPECT_EQ(z, 5u);
+  ASSERT_TRUE(deco_1->Is<ast::StageDecoration>());
+  EXPECT_EQ(deco_1->As<ast::StageDecoration>()->value(),
+            ast::PipelineStage::kCompute);
 }
 
 TEST_F(ParserImplTest, DecorationList_Empty) {
@@ -85,14 +84,12 @@
 }
 
 TEST_F(ParserImplTest, DecorationList_BadDecoration) {
-  auto p = parser("[[workgroup_size()]]");
+  auto p = parser("[[stage()]]");
   auto decos = p->decoration_list();
   EXPECT_TRUE(p->has_error());
   EXPECT_TRUE(decos.errored);
   EXPECT_FALSE(decos.matched);
-  EXPECT_EQ(
-      p->error(),
-      "1:18: expected signed integer literal for workgroup_size x parameter");
+  EXPECT_EQ(p->error(), "1:9: invalid value for stage decoration");
 }
 
 TEST_F(ParserImplTest, DecorationList_MissingRightAttr) {
diff --git a/src/reader/wgsl/parser_impl_function_decoration_test.cc b/src/reader/wgsl/parser_impl_function_decoration_test.cc
index dc9fd1f..432f99e 100644
--- a/src/reader/wgsl/parser_impl_function_decoration_test.cc
+++ b/src/reader/wgsl/parser_impl_function_decoration_test.cc
@@ -32,13 +32,16 @@
   ASSERT_NE(func_deco, nullptr);
   ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 4u);
-  EXPECT_EQ(y, 1u);
-  EXPECT_EQ(z, 1u);
+  auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  EXPECT_EQ(values[1], nullptr);
+  EXPECT_EQ(values[2], nullptr);
 }
 
 TEST_F(ParserImplTest, Decoration_Workgroup_2Param) {
@@ -52,13 +55,21 @@
   ASSERT_NE(func_deco, nullptr) << p->error();
   ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 4u);
-  EXPECT_EQ(y, 5u);
-  EXPECT_EQ(z, 1u);
+  auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(y_scalar, nullptr);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 5u);
+
+  EXPECT_EQ(values[2], nullptr);
 }
 
 TEST_F(ParserImplTest, Decoration_Workgroup_3Param) {
@@ -72,13 +83,52 @@
   ASSERT_NE(func_deco, nullptr);
   ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
 
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = func_deco->As<ast::WorkgroupDecoration>()->values();
-  EXPECT_EQ(x, 4u);
-  EXPECT_EQ(y, 5u);
-  EXPECT_EQ(z, 6u);
+  auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_scalar = values[1]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(y_scalar, nullptr);
+  ASSERT_TRUE(y_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(y_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 5u);
+
+  ASSERT_NE(values[2], nullptr);
+  auto* z_scalar = values[2]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(z_scalar, nullptr);
+  ASSERT_TRUE(z_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(z_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 6u);
+}
+
+TEST_F(ParserImplTest, Decoration_Workgroup_WithIdent) {
+  auto p = parser("workgroup_size(4, height)");
+  auto deco = p->decoration();
+  EXPECT_TRUE(deco.matched);
+  EXPECT_FALSE(deco.errored);
+  ASSERT_NE(deco.value, nullptr) << p->error();
+  ASSERT_FALSE(p->has_error());
+  auto* func_deco = deco.value->As<ast::Decoration>();
+  ASSERT_NE(func_deco, nullptr);
+  ASSERT_TRUE(func_deco->Is<ast::WorkgroupDecoration>());
+
+  auto values = func_deco->As<ast::WorkgroupDecoration>()->values();
+
+  ASSERT_NE(values[0], nullptr);
+  auto* x_scalar = values[0]->As<ast::ScalarConstructorExpression>();
+  ASSERT_NE(x_scalar, nullptr);
+  ASSERT_TRUE(x_scalar->literal()->Is<ast::IntLiteral>());
+  EXPECT_EQ(x_scalar->literal()->As<ast::IntLiteral>()->value_as_u32(), 4u);
+
+  ASSERT_NE(values[1], nullptr);
+  auto* y_ident = values[1]->As<ast::IdentifierExpression>();
+  ASSERT_NE(y_ident, nullptr);
+  EXPECT_EQ(p->builder().Symbols().NameFor(y_ident->symbol()), "height");
+
+  ASSERT_EQ(values[2], nullptr);
 }
 
 TEST_F(ParserImplTest, Decoration_Workgroup_TooManyValues) {
@@ -91,39 +141,6 @@
   EXPECT_EQ(p->error(), "1:23: expected ')' for workgroup_size decoration");
 }
 
-TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_X_Value) {
-  auto p = parser("workgroup_size(-2, 5, 6)");
-  auto deco = p->decoration();
-  EXPECT_FALSE(deco.matched);
-  EXPECT_TRUE(deco.errored);
-  EXPECT_EQ(deco.value, nullptr);
-  EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(p->error(),
-            "1:16: workgroup_size x parameter must be greater than 0");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_Y_Value) {
-  auto p = parser("workgroup_size(4, 0, 6)");
-  auto deco = p->decoration();
-  EXPECT_FALSE(deco.matched);
-  EXPECT_TRUE(deco.errored);
-  EXPECT_EQ(deco.value, nullptr);
-  EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(p->error(),
-            "1:19: workgroup_size y parameter must be greater than 0");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Invalid_Z_Value) {
-  auto p = parser("workgroup_size(4, 5, -3)");
-  auto deco = p->decoration();
-  EXPECT_FALSE(deco.matched);
-  EXPECT_TRUE(deco.errored);
-  EXPECT_EQ(deco.value, nullptr);
-  EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(p->error(),
-            "1:22: workgroup_size z parameter must be greater than 0");
-}
-
 TEST_F(ParserImplTest, Decoration_Workgroup_MissingLeftParam) {
   auto p = parser("workgroup_size 4, 5, 6)");
   auto deco = p->decoration();
@@ -151,9 +168,7 @@
   EXPECT_TRUE(deco.errored);
   EXPECT_EQ(deco.value, nullptr);
   EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:16: expected signed integer literal for workgroup_size x parameter");
+  EXPECT_EQ(p->error(), "1:16: expected workgroup_size x parameter");
 }
 
 TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Value) {
@@ -163,9 +178,7 @@
   EXPECT_TRUE(deco.errored);
   EXPECT_EQ(deco.value, nullptr);
   EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:16: expected signed integer literal for workgroup_size x parameter");
+  EXPECT_EQ(p->error(), "1:16: expected workgroup_size x parameter");
 }
 
 TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Comma) {
@@ -185,9 +198,7 @@
   EXPECT_TRUE(deco.errored);
   EXPECT_EQ(deco.value, nullptr);
   EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:19: expected signed integer literal for workgroup_size y parameter");
+  EXPECT_EQ(p->error(), "1:19: expected workgroup_size y parameter");
 }
 
 TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Comma) {
@@ -207,45 +218,7 @@
   EXPECT_TRUE(deco.errored);
   EXPECT_EQ(deco.value, nullptr);
   EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:22: expected signed integer literal for workgroup_size z parameter");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Missing_X_Invalid) {
-  auto p = parser("workgroup_size(nan)");
-  auto deco = p->decoration();
-  EXPECT_FALSE(deco.matched);
-  EXPECT_TRUE(deco.errored);
-  EXPECT_EQ(deco.value, nullptr);
-  EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:16: expected signed integer literal for workgroup_size x parameter");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Y_Invalid) {
-  auto p = parser("workgroup_size(2, nan)");
-  auto deco = p->decoration();
-  EXPECT_FALSE(deco.matched);
-  EXPECT_TRUE(deco.errored);
-  EXPECT_EQ(deco.value, nullptr);
-  EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:19: expected signed integer literal for workgroup_size y parameter");
-}
-
-TEST_F(ParserImplTest, Decoration_Workgroup_Missing_Z_Invalid) {
-  auto p = parser("workgroup_size(2, 3, nan)");
-  auto deco = p->decoration();
-  EXPECT_FALSE(deco.matched);
-  EXPECT_TRUE(deco.errored);
-  EXPECT_EQ(deco.value, nullptr);
-  EXPECT_TRUE(p->has_error());
-  EXPECT_EQ(
-      p->error(),
-      "1:22: expected signed integer literal for workgroup_size z parameter");
+  EXPECT_EQ(p->error(), "1:22: expected workgroup_size z parameter");
 }
 
 TEST_F(ParserImplTest, Decoration_Stage) {
diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc
index aac4784..15f8f7e 100644
--- a/src/resolver/decoration_validation_test.cc
+++ b/src/resolver/decoration_validation_test.cc
@@ -94,7 +94,8 @@
     case DecorationKind::kStructBlock:
       return {builder.create<ast::StructBlockDecoration>(source)};
     case DecorationKind::kWorkgroup:
-      return {builder.create<ast::WorkgroupDecoration>(source, 1u, 1u, 1u)};
+      return {
+          builder.create<ast::WorkgroupDecoration>(source, builder.Expr(1))};
     case DecorationKind::kBindingAndGroup:
       return {builder.create<ast::BindingDecoration>(source, 1u),
               builder.create<ast::GroupDecoration>(source, 1u)};
@@ -664,7 +665,7 @@
 
 TEST_F(WorkgroupDecoration, NotAnEntryPoint) {
   Func("main", {}, ty.void_(), {},
-       {create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
+       {create<ast::WorkgroupDecoration>(Source{{12, 34}}, Expr(1))});
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
@@ -675,7 +676,7 @@
 TEST_F(WorkgroupDecoration, NotAComputeShader) {
   Func("main", {}, ty.void_(), {},
        {Stage(ast::PipelineStage::kFragment),
-        create<ast::WorkgroupDecoration>(Source{{12, 34}}, 1u)});
+        create<ast::WorkgroupDecoration>(Source{{12, 34}}, Expr(1))});
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
@@ -685,9 +686,8 @@
 
 TEST_F(WorkgroupDecoration, MultipleAttributes) {
   Func(Source{{12, 34}}, "main", {}, ty.void_(), {},
-       {Stage(ast::PipelineStage::kCompute),
-        create<ast::WorkgroupDecoration>(1u),
-        create<ast::WorkgroupDecoration>(2u)});
+       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1),
+        WorkgroupSize(2)});
 
   EXPECT_FALSE(r()->Resolve());
   EXPECT_EQ(r()->error(),
diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc
index 423270b..a542b9d 100644
--- a/src/resolver/function_validation_test.cc
+++ b/src/resolver/function_validation_test.cc
@@ -244,5 +244,127 @@
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 }
 
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_BadType) {
+  // [[stage(compute), workgroup_size(64.0)]
+  // fn main() {}
+
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(create<ast::ScalarConstructorExpression>(
+            Source{Source::Location{12, 34}}, Literal(64.f)))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be a literal i32 or an "
+            "i32 module-scope constant");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Negative) {
+  // [[stage(compute), workgroup_size(-2)]
+  // fn main() {}
+
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(create<ast::ScalarConstructorExpression>(
+            Source{Source::Location{12, 34}}, Literal(-2)))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Literal_Zero) {
+  // [[stage(compute), workgroup_size(0)]
+  // fn main() {}
+
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(create<ast::ScalarConstructorExpression>(
+            Source{Source::Location{12, 34}}, Literal(0)))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_BadType) {
+  // let x = 64.0;
+  // [[stage(compute), workgroup_size(x)]
+  // fn main() {}
+  GlobalConst("x", ty.f32(), Expr(64.f));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be a literal i32 or an "
+            "i32 module-scope constant");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Negative) {
+  // let x = -2;
+  // [[stage(compute), workgroup_size(x)]
+  // fn main() {}
+  GlobalConst("x", ty.i32(), Expr(-2));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_Const_Zero) {
+  // let x = 0;
+  // [[stage(compute), workgroup_size(x)]
+  // fn main() {}
+  GlobalConst("x", ty.i32(), Expr(0));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+       WorkgroupSize_Const_NestedZeroValueConstructor) {
+  // let x = i32(i32(i32()));
+  // [[stage(compute), workgroup_size(x)]
+  // fn main() {}
+  GlobalConst("x", ty.i32(),
+              Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32()))));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(
+      r()->error(),
+      "12:34 error: workgroup_size parameter must be a positive i32 value");
+}
+
+TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
+  // var<private> x = 0;
+  // [[stage(compute), workgroup_size(x)]
+  // fn main() {}
+  Global("x", ty.i32(), ast::StorageClass::kPrivate, Expr(64));
+  Func("main", {}, ty.void_(), {},
+       {Stage(ast::PipelineStage::kCompute),
+        WorkgroupSize(Expr(Source{Source::Location{12, 34}}, "x"))});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error: workgroup_size parameter must be a literal i32 or an "
+            "i32 module-scope constant");
+}
+
 }  // namespace
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 93266aa..227331b 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1295,10 +1295,79 @@
 
   if (auto* workgroup =
           ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
-    // TODO(crbug.com/tint/713): Handle non-literals.
-    info->workgroup_size[0].value = std::get<0>(workgroup->values());
-    info->workgroup_size[1].value = std::get<1>(workgroup->values());
-    info->workgroup_size[2].value = std::get<2>(workgroup->values());
+    auto values = workgroup->values();
+    for (int i = 0; i < 3; i++) {
+      // Each argument to this decoration can either be a literal, an
+      // identifier for a module-scope constants, or nullptr if not specified.
+
+      if (!values[i]) {
+        // Not specified, just use the default.
+        continue;
+      }
+
+      Mark(values[i]);
+
+      int32_t value = 0;
+      if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
+        // We have an identifier of a module-scope constant.
+        if (!Identifier(ident)) {
+          return false;
+        }
+
+        VariableInfo* var;
+        if (!variable_stack_.get(ident->symbol(), &var) ||
+            !(var->declaration->is_const() && var->type->Is<sem::I32>())) {
+          diagnostics_.add_error(
+              "workgroup_size parameter must be a literal i32 or an i32 "
+              "module-scope constant",
+              values[i]->source());
+          return false;
+        }
+
+        // Capture the constant if an [[override]] attribute is present.
+        if (ast::HasDecoration<ast::OverrideDecoration>(
+                var->declaration->decorations())) {
+          info->workgroup_size[i].overridable_const = var->declaration;
+        }
+
+        auto* constructor = var->declaration->constructor();
+        if (constructor) {
+          // Resolve the constructor expression to use as the default value.
+          if (!GetScalarConstExprValue(constructor, &value)) {
+            return false;
+          }
+        } else {
+          // No constructor means this value must be overriden by the user.
+          info->workgroup_size[i].value = 0;
+          continue;
+        }
+      } else if (auto* scalar =
+                     values[i]->As<ast::ScalarConstructorExpression>()) {
+        // We have a literal.
+        Mark(scalar->literal());
+
+        if (!scalar->literal()->Is<ast::IntLiteral>()) {
+          diagnostics_.add_error(
+              "workgroup_size parameter must be a literal i32 or an i32 "
+              "module-scope constant",
+              values[i]->source());
+          return false;
+        }
+
+        if (!GetScalarConstExprValue(scalar, &value)) {
+          return false;
+        }
+      }
+
+      // Validate and set the default value for this dimension.
+      if (value < 1) {
+        diagnostics_.add_error(
+            "workgroup_size parameter must be a positive i32 value",
+            values[i]->source());
+        return false;
+      }
+      info->workgroup_size[i].value = value;
+    }
   }
 
   if (!ValidateFunction(func, info)) {
@@ -3098,6 +3167,40 @@
   return true;
 }
 
+template <typename T>
+bool Resolver::GetScalarConstExprValue(ast::Expression* expr, T* result) {
+  if (auto* type_constructor = expr->As<ast::TypeConstructorExpression>()) {
+    if (type_constructor->values().size() == 0) {
+      // Zero-valued constructor.
+      *result = static_cast<T>(0);
+      return true;
+    } else if (type_constructor->values().size() == 1) {
+      // Recurse into the constructor argument expression.
+      return GetScalarConstExprValue(type_constructor->values()[0], result);
+    } else {
+      TINT_ICE(diagnostics_) << "malformed scalar type constructor";
+    }
+  } else if (auto* scalar = expr->As<ast::ScalarConstructorExpression>()) {
+    // Cast literal to result type.
+    if (auto* int_lit = scalar->literal()->As<ast::IntLiteral>()) {
+      *result = static_cast<T>(int_lit->value_as_u32());
+      return true;
+    } else if (auto* float_lit = scalar->literal()->As<ast::FloatLiteral>()) {
+      *result = static_cast<T>(float_lit->value());
+      return true;
+    } else if (auto* bool_lit = scalar->literal()->As<ast::BoolLiteral>()) {
+      *result = static_cast<T>(bool_lit->IsTrue());
+      return true;
+    } else {
+      TINT_ICE(diagnostics_) << "unhandled scalar constructor";
+    }
+  } else {
+    TINT_ICE(diagnostics_) << "unhandled constant expression";
+  }
+
+  return false;
+}
+
 template <typename F>
 bool Resolver::BlockScope(const ast::BlockStatement* block, F&& callback) {
   auto* sem_block = builder_->Sem().Get<sem::BlockStatement>(block);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index ad428a8..8a622d2 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -323,6 +323,13 @@
                typ::Type type,
                const std::string& type_name);
 
+  /// Resolve the value of a scalar const_expr.
+  /// @param expr the expression
+  /// @param result pointer to the where the result will be stored
+  /// @returns true on success, false on error
+  template <typename T>
+  bool GetScalarConstExprValue(ast::Expression* expr, T* result);
+
   /// Constructs a new semantic BlockStatement with the given type and with
   /// #current_block_ as its parent, assigns this to #current_block_, and then
   /// calls `callback`. The original #current_block_ is restored on exit.
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 109b31a..41e15e2 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -26,6 +26,7 @@
 #include "src/ast/if_statement.h"
 #include "src/ast/intrinsic_texture_helper_test.h"
 #include "src/ast/loop_statement.h"
+#include "src/ast/override_decoration.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/stage_decoration.h"
 #include "src/ast/struct_block_decoration.h"
@@ -889,6 +890,8 @@
 }
 
 TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
+  // [[stage(compute)]]
+  // fn main() {}
   auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
 
   EXPECT_TRUE(r()->Resolve()) << r()->error();
@@ -905,9 +908,11 @@
 }
 
 TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
-  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
-                    {Stage(ast::PipelineStage::kCompute),
-                     create<ast::WorkgroupDecoration>(8, 2, 3)});
+  // [[stage(compute), workgroup_size(8, 2, 3)]]
+  // fn main() {}
+  auto* func =
+      Func("main", ast::VariableList{}, ty.void_(), {},
+           {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 3)});
 
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 
@@ -922,6 +927,134 @@
   EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
 }
 
+TEST_F(ResolverTest, Function_WorkgroupSize_Consts) {
+  // let width = 16;
+  // let height = 8;
+  // let depth = 2;
+  // [[stage(compute), workgroup_size(width, height, depth)]]
+  // fn main() {}
+  GlobalConst("width", ty.i32(), Expr(16));
+  GlobalConst("height", ty.i32(), Expr(8));
+  GlobalConst("depth", ty.i32(), Expr(2));
+  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+                    {Stage(ast::PipelineStage::kCompute),
+                     WorkgroupSize("width", "height", "depth")});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 16u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 8u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 2u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Consts_NestedInitializer) {
+  // let width = i32(i32(i32(8)));
+  // let height = i32(i32(i32(4)));
+  // [[stage(compute), workgroup_size(width, height)]]
+  // fn main() {}
+  GlobalConst("width", ty.i32(),
+              Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8))));
+  GlobalConst("height", ty.i32(),
+              Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 4))));
+  auto* func = Func(
+      "main", ast::VariableList{}, ty.void_(), {},
+      {Stage(ast::PipelineStage::kCompute), WorkgroupSize("width", "height")});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 4u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 1u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
+  // [[override(0)]] let width = 16;
+  // [[override(1)]] let height = 8;
+  // [[override(2)]] let depth = 2;
+  // [[stage(compute), workgroup_size(width, height, depth)]]
+  // fn main() {}
+  auto* width = GlobalConst("width", ty.i32(), Expr(16), {Override(0)});
+  auto* height = GlobalConst("height", ty.i32(), Expr(8), {Override(1)});
+  auto* depth = GlobalConst("depth", ty.i32(), Expr(2), {Override(2)});
+  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+                    {Stage(ast::PipelineStage::kCompute),
+                     WorkgroupSize("width", "height", "depth")});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 16u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 8u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 2u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, width);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, depth);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
+  // [[override(0)]] let width : i32;
+  // [[override(1)]] let height : i32;
+  // [[override(2)]] let depth : i32;
+  // [[stage(compute), workgroup_size(width, height, depth)]]
+  // fn main() {}
+  auto* width = GlobalConst("width", ty.i32(), nullptr, {Override(0)});
+  auto* height = GlobalConst("height", ty.i32(), nullptr, {Override(1)});
+  auto* depth = GlobalConst("depth", ty.i32(), nullptr, {Override(2)});
+  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+                    {Stage(ast::PipelineStage::kCompute),
+                     WorkgroupSize("width", "height", "depth")});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 0u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 0u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 0u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, width);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, depth);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
+  // [[override(1)]] let height = 2;
+  // let depth = 3;
+  // [[stage(compute), workgroup_size(8, height, depth)]]
+  // fn main() {}
+  auto* height = GlobalConst("height", ty.i32(), Expr(2), {Override(0)});
+  GlobalConst("depth", ty.i32(), Expr(3));
+  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+                    {Stage(ast::PipelineStage::kCompute),
+                     WorkgroupSize(8, "height", "depth")});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 2u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 3u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, height);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
 TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
   auto* st = Structure("S", {Member("first_member", ty.i32()),
                              Member("second_member", ty.f32())});
diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc
index 68f55c7..502b60d 100644
--- a/src/writer/hlsl/generator_impl_function_test.cc
+++ b/src/writer/hlsl/generator_impl_function_test.cc
@@ -926,7 +926,7 @@
        },
        {
            Stage(ast::PipelineStage::kCompute),
-           create<ast::WorkgroupDecoration>(2u, 4u, 6u),
+           WorkgroupSize(2, 4, 6),
        });
 
   GeneratorImpl& gen = Build();
diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc
index c394926..91bc005 100644
--- a/src/writer/spirv/builder_function_decoration_test.cc
+++ b/src/writer/spirv/builder_function_decoration_test.cc
@@ -200,7 +200,7 @@
 TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) {
   auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
                     ast::DecorationList{
-                        create<ast::WorkgroupDecoration>(2u, 4u, 6u),
+                        WorkgroupSize(2, 4, 6),
                         Stage(ast::PipelineStage::kCompute),
                     });
 
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index 45a77b8..ef96c53 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -601,12 +601,28 @@
     first = false;
 
     if (auto* workgroup = deco->As<ast::WorkgroupDecoration>()) {
-      uint32_t x = 0;
-      uint32_t y = 0;
-      uint32_t z = 0;
-      std::tie(x, y, z) = workgroup->values();
-      out_ << "workgroup_size(" << std::to_string(x) << ", "
-           << std::to_string(y) << ", " << std::to_string(z) << ")";
+      auto values = workgroup->values();
+      out_ << "workgroup_size(";
+      for (int i = 0; i < 3; i++) {
+        if (values[i]) {
+          if (i > 0) {
+            out_ << ", ";
+          }
+          if (auto* ident = values[i]->As<ast::IdentifierExpression>()) {
+            if (!EmitIdentifier(ident)) {
+              return false;
+            }
+          } else if (auto* scalar =
+                         values[i]->As<ast::ScalarConstructorExpression>()) {
+            if (!EmitScalarConstructor(scalar)) {
+              return false;
+            }
+          } else {
+            TINT_ICE(diagnostics_) << "Unsupported workgroup_size expression";
+          }
+        }
+      }
+      out_ << ")";
     } else if (deco->Is<ast::StructBlockDecoration>()) {
       out_ << "block";
     } else if (auto* stage = deco->As<ast::StageDecoration>()) {
diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc
index fc69ccf..0e19f5a 100644
--- a/src/writer/wgsl/generator_impl_function_test.cc
+++ b/src/writer/wgsl/generator_impl_function_test.cc
@@ -69,13 +69,10 @@
 
 TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) {
   auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
-                    ast::StatementList{
-                        create<ast::DiscardStatement>(),
-                        Return(),
-                    },
+                    ast::StatementList{Return()},
                     ast::DecorationList{
                         Stage(ast::PipelineStage::kCompute),
-                        create<ast::WorkgroupDecoration>(2u, 4u, 6u),
+                        WorkgroupSize(2, 4, 6),
                     });
 
   GeneratorImpl& gen = Build();
@@ -85,20 +82,19 @@
   ASSERT_TRUE(gen.EmitFunction(func));
   EXPECT_EQ(gen.result(), R"(  [[stage(compute), workgroup_size(2, 4, 6)]]
   fn my_func() {
-    discard;
     return;
   }
 )");
 }
 
-TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) {
+TEST_F(WgslGeneratorImplTest,
+       Emit_Function_WithDecoration_WorkgroupSize_WithIdent) {
+  GlobalConst("height", ty.i32(), Expr(2));
   auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
-                    ast::StatementList{
-                        create<ast::DiscardStatement>(),
-                        Return(),
-                    },
+                    ast::StatementList{Return()},
                     ast::DecorationList{
-                        Stage(ast::PipelineStage::kFragment),
+                        Stage(ast::PipelineStage::kCompute),
+                        WorkgroupSize(2, "height"),
                     });
 
   GeneratorImpl& gen = Build();
@@ -106,9 +102,8 @@
   gen.increment_indent();
 
   ASSERT_TRUE(gen.EmitFunction(func));
-  EXPECT_EQ(gen.result(), R"(  [[stage(fragment)]]
+  EXPECT_EQ(gen.result(), R"(  [[stage(compute), workgroup_size(2, height)]]
   fn my_func() {
-    discard;
     return;
   }
 )");
diff --git a/test/samples/function.wgsl.expected.wgsl b/test/samples/function.wgsl.expected.wgsl
index d8a4c35..4cd4fb6 100644
--- a/test/samples/function.wgsl.expected.wgsl
+++ b/test/samples/function.wgsl.expected.wgsl
@@ -2,6 +2,6 @@
   return (((2.0 * 3.0) - 4.0) / 5.0);
 }
 
-[[stage(compute), workgroup_size(2, 1, 1)]]
+[[stage(compute), workgroup_size(2)]]
 fn ep() {
 }