Move workgroup_size property into sem::Function

The workgroup size should not be a property of the function in the
AST, and this lays the groundwork for allowing both literals and
module-scope constants to be used for this attribute.

Bug: tint:713
Change-Id: I014be879e2adb81cfc5b0ea0e221035fae626223
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51261
Auto-Submit: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index 24b0227..0896e7a 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -58,13 +58,6 @@
 
 Function::~Function() = default;
 
-std::tuple<uint32_t, uint32_t, uint32_t> Function::workgroup_size() const {
-  if (auto* workgroup = GetDecoration<WorkgroupDecoration>(decorations_)) {
-    return workgroup->values();
-  }
-  return {1, 1, 1};
-}
-
 PipelineStage Function::pipeline_stage() const {
   if (auto* stage = GetDecoration<StageDecoration>(decorations_)) {
     return stage->value();
diff --git a/src/ast/function.h b/src/ast/function.h
index fd91c7c..397e1c4 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -66,10 +66,6 @@
   /// @returns the decorations attached to this function
   const DecorationList& decorations() const { return decorations_; }
 
-  /// @returns the workgroup size {x, y, z} for the function. {1, 1, 1} will be
-  /// return if no workgroup size was set.
-  std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const;
-
   /// @returns the functions pipeline stage or None if not set
   PipelineStage pipeline_stage() const;
 
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 60c1525..44ba686 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -225,31 +225,6 @@
   EXPECT_EQ(f->get_last_statement(), nullptr);
 }
 
-TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
-  auto* f = Func("func", VariableList{}, ty.void_(), StatementList{},
-                 DecorationList{});
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = f->workgroup_size();
-  EXPECT_EQ(x, 1u);
-  EXPECT_EQ(y, 1u);
-  EXPECT_EQ(z, 1u);
-}
-
-TEST_F(FunctionTest, WorkgroupSize) {
-  auto* f = Func("func", VariableList{}, ty.void_(), StatementList{},
-                 DecorationList{create<WorkgroupDecoration>(2u, 4u, 6u)});
-
-  uint32_t x = 0;
-  uint32_t y = 0;
-  uint32_t z = 0;
-  std::tie(x, y, z) = f->workgroup_size();
-  EXPECT_EQ(x, 2u);
-  EXPECT_EQ(y, 4u);
-  EXPECT_EQ(z, 6u);
-}
-
 using FunctionListTest = TestHelper;
 
 TEST_F(FunctionListTest, FindSymbol) {
diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc
index b8ac9b4..b687b22 100644
--- a/src/inspector/inspector.cc
+++ b/src/inspector/inspector.cc
@@ -198,8 +198,16 @@
     entry_point.name = program_->Symbols().NameFor(func->symbol());
     entry_point.remapped_name = program_->Symbols().NameFor(func->symbol());
     entry_point.stage = func->pipeline_stage();
-    std::tie(entry_point.workgroup_size_x, entry_point.workgroup_size_y,
-             entry_point.workgroup_size_z) = func->workgroup_size();
+
+    auto wgsize = sem->workgroup_size();
+    entry_point.workgroup_size_x = wgsize[0].value;
+    entry_point.workgroup_size_y = wgsize[1].value;
+    entry_point.workgroup_size_z = wgsize[2].value;
+    if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
+        wgsize[2].overridable_const) {
+      // TODO(crbug.com/tint/713): Handle overridable constants.
+      TINT_ASSERT(false);
+    }
 
     for (auto* param : sem->Parameters()) {
       AddEntryPointInOutVariables(
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index a80472b..93266aa 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -1287,6 +1287,20 @@
     Mark(deco);
   }
 
+  // Set work-group size defaults.
+  for (int i = 0; i < 3; i++) {
+    info->workgroup_size[i].value = 1;
+    info->workgroup_size[i].overridable_const = nullptr;
+  }
+
+  if (auto* workgroup =
+          ast::GetDecoration<ast::WorkgroupDecoration>(func->decorations())) {
+    // TODO(crbug.com/tint/713): Handle non-literals.
+    info->workgroup_size[0].value = std::get<0>(workgroup->values());
+    info->workgroup_size[1].value = std::get<1>(workgroup->values());
+    info->workgroup_size[2].value = std::get<2>(workgroup->values());
+  }
+
   if (!ValidateFunction(func, info)) {
     return false;
   }
@@ -2517,7 +2531,7 @@
         info->declaration, const_cast<sem::Type*>(info->return_type),
         remap_vars(info->parameters), remap_vars(info->referenced_module_vars),
         remap_vars(info->local_referenced_module_vars), info->return_statements,
-        ancestor_entry_points[func->symbol()]);
+        ancestor_entry_points[func->symbol()], info->workgroup_size);
     func_info_to_sem_func.emplace(info, sem_func);
     sem.Add(func, sem_func);
   }
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 76a722f..ad428a8 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -26,6 +26,7 @@
 #include "src/scope_stack.h"
 #include "src/sem/binding_point.h"
 #include "src/sem/block_statement.h"
+#include "src/sem/function.h"
 #include "src/sem/struct.h"
 #include "src/utils/unique_vector.h"
 
@@ -112,6 +113,7 @@
     std::vector<const ast::ReturnStatement*> return_statements;
     sem::Type* return_type = nullptr;
     std::string return_type_name;
+    std::array<sem::WorkgroupDimension, 3> workgroup_size;
 
     // List of transitive calls this function makes
     UniqueVector<FunctionInfo*> transitive_calls;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index e81be39..109b31a 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -32,6 +32,7 @@
 #include "src/ast/switch_statement.h"
 #include "src/ast/unary_op_expression.h"
 #include "src/ast/variable_decl_statement.h"
+#include "src/ast/workgroup_decoration.h"
 #include "src/resolver/resolver_test_helper.h"
 #include "src/sem/call.h"
 #include "src/sem/function.h"
@@ -887,6 +888,40 @@
   EXPECT_TRUE(func_sem->ReturnType()->Is<sem::F32>());
 }
 
+TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
+  auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 1u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 1u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 1u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
+TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
+  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
+                    {Stage(ast::PipelineStage::kCompute),
+                     create<ast::WorkgroupDecoration>(8, 2, 3)});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->workgroup_size()[0].value, 8u);
+  EXPECT_EQ(func_sem->workgroup_size()[1].value, 2u);
+  EXPECT_EQ(func_sem->workgroup_size()[2].value, 3u);
+  EXPECT_EQ(func_sem->workgroup_size()[0].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[1].overridable_const, nullptr);
+  EXPECT_EQ(func_sem->workgroup_size()[2].overridable_const, nullptr);
+}
+
 TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
   auto* st = Structure("S", {Member("first_member", ty.i32()),
                              Member("second_member", ty.f32())});
diff --git a/src/sem/function.cc b/src/sem/function.cc
index 0c5e844..f33fc07 100644
--- a/src/sem/function.cc
+++ b/src/sem/function.cc
@@ -46,14 +46,16 @@
                    std::vector<const Variable*> referenced_module_vars,
                    std::vector<const Variable*> local_referenced_module_vars,
                    std::vector<const ast::ReturnStatement*> return_statements,
-                   std::vector<Symbol> ancestor_entry_points)
+                   std::vector<Symbol> ancestor_entry_points,
+                   std::array<WorkgroupDimension, 3> workgroup_size)
     : Base(return_type, GetParameters(parameters)),
       declaration_(declaration),
       parameters_(std::move(parameters)),
       referenced_module_vars_(std::move(referenced_module_vars)),
       local_referenced_module_vars_(std::move(local_referenced_module_vars)),
       return_statements_(std::move(return_statements)),
-      ancestor_entry_points_(std::move(ancestor_entry_points)) {}
+      ancestor_entry_points_(std::move(ancestor_entry_points)),
+      workgroup_size_(std::move(workgroup_size)) {}
 
 Function::~Function() = default;
 
diff --git a/src/sem/function.h b/src/sem/function.h
index a29e318..94cc02e 100644
--- a/src/sem/function.h
+++ b/src/sem/function.h
@@ -15,6 +15,7 @@
 #ifndef SRC_SEM_FUNCTION_H_
 #define SRC_SEM_FUNCTION_H_
 
+#include <array>
 #include <utility>
 #include <vector>
 
@@ -37,6 +38,16 @@
 
 class Variable;
 
+/// WorkgroupDimension describes the size of a single dimension of an entry
+/// point's workgroup size.
+struct WorkgroupDimension {
+  /// The size of this dimension.
+  uint32_t value;
+  /// A pipeline-overridable constant that overrides the size, or nullptr if
+  /// this dimension is not overridable.
+  const ast::Variable* overridable_const = nullptr;
+};
+
 /// Function holds the semantic information for function nodes.
 class Function : public Castable<Function, CallTarget> {
  public:
@@ -53,13 +64,15 @@
   /// @param return_statements the function return statements
   /// variables
   /// @param ancestor_entry_points the ancestor entry points
+  /// @param workgroup_size the workgroup size
   Function(ast::Function* declaration,
            Type* return_type,
            std::vector<const Variable*> parameters,
            std::vector<const Variable*> referenced_module_vars,
            std::vector<const Variable*> local_referenced_module_vars,
            std::vector<const ast::ReturnStatement*> return_statements,
-           std::vector<Symbol> ancestor_entry_points);
+           std::vector<Symbol> ancestor_entry_points,
+           std::array<WorkgroupDimension, 3> workgroup_size);
 
   /// Destructor
   ~Function() override;
@@ -148,6 +161,11 @@
   /// @returns true if `sym` is an ancestor entry point of this function
   bool HasAncestorEntryPoint(Symbol sym) const;
 
+  /// @returns the workgroup size {x, y, z} for the function.
+  const std::array<WorkgroupDimension, 3>& workgroup_size() const {
+    return workgroup_size_;
+  }
+
  private:
   VariableBindings ReferencedSamplerVariablesImpl(ast::SamplerKind kind) const;
   VariableBindings ReferencedSampledTextureVariablesImpl(
@@ -159,6 +177,7 @@
   std::vector<const Variable*> const local_referenced_module_vars_;
   std::vector<const ast::ReturnStatement*> const return_statements_;
   std::vector<Symbol> const ancestor_entry_points_;
+  std::array<WorkgroupDimension, 3> workgroup_size_;
 };
 
 }  // namespace sem
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index e894106..b516533 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -1989,12 +1989,19 @@
   make_indent(out);
 
   current_ep_sym_ = func->symbol();
+  auto* func_sem = builder_.Sem().Get(func);
 
   if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
-    uint32_t x = 0;
-    uint32_t y = 0;
-    uint32_t z = 0;
-    std::tie(x, y, z) = func->workgroup_size();
+    auto wgsize = func_sem->workgroup_size();
+    if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
+        wgsize[2].overridable_const) {
+      // TODO(crbug.com/tint/713): Handle overridable constants.
+      TINT_UNIMPLEMENTED(builder_.Diagnostics())
+          << "pipeline-overridable workgroup sizes are not implemented";
+    }
+    uint32_t x = wgsize[0].value;
+    uint32_t y = wgsize[1].value;
+    uint32_t z = wgsize[2].value;
     out << "[numthreads(" << std::to_string(x) << ", " << std::to_string(y)
         << ", " << std::to_string(z) << ")]" << std::endl;
     make_indent(out);
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 22f4a6a..60c159c 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -435,23 +435,30 @@
 }
 
 bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) {
+  auto* func_sem = builder_.Sem().Get(func);
+
   // WGSL fragment shader origin is upper left
   if (func->pipeline_stage() == ast::PipelineStage::kFragment) {
     push_execution_mode(
         spv::Op::OpExecutionMode,
         {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
   } else if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
-    uint32_t x = 0;
-    uint32_t y = 0;
-    uint32_t z = 0;
-    std::tie(x, y, z) = func->workgroup_size();
+    auto& wgsize = func_sem->workgroup_size();
+    if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
+        wgsize[2].overridable_const) {
+      // TODO(crbug.com/tint/713): Handle overridable constants.
+      TINT_UNIMPLEMENTED(builder_.Diagnostics())
+          << "pipeline-overridable workgroup sizes are not implemented";
+    }
+    uint32_t x = wgsize[0].value;
+    uint32_t y = wgsize[1].value;
+    uint32_t z = wgsize[2].value;
     push_execution_mode(
         spv::Op::OpExecutionMode,
         {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
          Operand::Int(x), Operand::Int(y), Operand::Int(z)});
   }
 
-  auto* func_sem = builder_.Sem().Get(func);
   for (auto builtin : func_sem->ReferencedBuiltinVariables()) {
     if (builtin.second->value() == ast::Builtin::kFragDepth) {
       push_execution_mode(