resolver: Track global uses in function decorations

Fixed: tint:1320
Change-Id: Ib92c37d4de0641d11e508be4d8e05d641e808be9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/70662
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/docs/origin-trial-changes.md b/docs/origin-trial-changes.md
index 3437a84..d34c2bd 100644
--- a/docs/origin-trial-changes.md
+++ b/docs/origin-trial-changes.md
@@ -16,6 +16,10 @@
 * The `dot()` builtin now supports integer vector types.
 * Identifiers can now start with a single leading underscore.  [tint:1292](https://crbug.com/tint/1292)
 
+### Fixes
+
+* Fixed an issue where using a module-scoped `let` in a `workgroup_size` may result in a compilation error. [tint:1320](https://crbug.com/tint/1320)
+
 ## Changes for M97
 
 ### Breaking Changes
diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc
index af3b097..302ad88 100644
--- a/src/resolver/function_validation_test.cc
+++ b/src/resolver/function_validation_test.cc
@@ -427,15 +427,26 @@
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_ConstU32) {
   // let x = 4u;
   // let x = 8u;
-  // [[stage(compute), workgroup_size(x, y, 16u]
+  // [[stage(compute), workgroup_size(x, y, 16u)]]
   // fn main() {}
-  GlobalConst("x", ty.u32(), Expr(4u));
-  GlobalConst("y", ty.u32(), Expr(8u));
-  Func("main", {}, ty.void_(), {},
-       {Stage(ast::PipelineStage::kCompute),
-        WorkgroupSize(Expr("x"), Expr("y"), Expr(16u))});
+  auto* x = GlobalConst("x", ty.u32(), Expr(4u));
+  auto* y = GlobalConst("y", ty.u32(), Expr(8u));
+  auto* func = Func("main", {}, ty.void_(), {},
+                    {Stage(ast::PipelineStage::kCompute),
+                     WorkgroupSize(Expr("x"), Expr("y"), Expr(16u))});
 
   ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* sem_func = Sem().Get(func);
+  auto* sem_x = Sem().Get<sem::GlobalVariable>(x);
+  auto* sem_y = Sem().Get<sem::GlobalVariable>(y);
+
+  ASSERT_NE(sem_func, nullptr);
+  ASSERT_NE(sem_x, nullptr);
+  ASSERT_NE(sem_y, nullptr);
+
+  EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_x));
+  EXPECT_TRUE(sem_func->DirectlyReferencedGlobals().contains(sem_y));
 }
 
 TEST_F(ResolverFunctionValidationTest, WorkgroupSize_GoodType_U32) {
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 0e0abc9..3eb96a3 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -634,21 +634,19 @@
     }
   }
 
-  sem::WorkgroupSize ws{};
-  if (!WorkgroupSizeFor(decl, ws)) {
+  auto* func = builder_->create<sem::Function>(decl, return_type, parameters);
+  builder_->Sem().Add(decl, func);
+
+  TINT_SCOPED_ASSIGNMENT(current_function_, func);
+
+  if (!WorkgroupSize(decl)) {
     return nullptr;
   }
 
-  auto* func =
-      builder_->create<sem::Function>(decl, return_type, parameters, ws);
-  builder_->Sem().Add(decl, func);
-
   if (decl->IsEntryPoint()) {
     entry_points_.emplace_back(func);
   }
 
-  TINT_SCOPED_ASSIGNMENT(current_function_, func);
-
   if (decl->body) {
     Mark(decl->body);
     if (current_compound_statement_) {
@@ -692,9 +690,9 @@
   return func;
 }
 
-bool Resolver::WorkgroupSizeFor(const ast::Function* func,
-                                sem::WorkgroupSize& ws) {
+bool Resolver::WorkgroupSize(const ast::Function* func) {
   // Set work-group size defaults.
+  sem::WorkgroupSize ws;
   for (int i = 0; i < 3; i++) {
     ws[i].value = 1;
     ws[i].overridable_const = nullptr;
@@ -790,6 +788,8 @@
     ws[i].value = is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32)
                          : value.Elements()[0].u32;
   }
+
+  current_function_->SetWorkgroupSize(std::move(ws));
   return true;
 }
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index cca4440..37dea95 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -282,8 +282,9 @@
   bool IsValidationEnabled(const ast::DecorationList& decorations,
                            ast::DisabledValidation validation) const;
 
-  /// Resolves the WorkgroupSize for the given function
-  bool WorkgroupSizeFor(const ast::Function*, sem::WorkgroupSize& ws);
+  /// Resolves the WorkgroupSize for the given function, assigning it to
+  /// current_function_
+  bool WorkgroupSize(const ast::Function*);
 
   /// @returns the sem::Type for the ast::Type `ty`, building it if it
   /// hasn't been constructed already. If an error is raised, nullptr is
diff --git a/src/sem/function.cc b/src/sem/function.cc
index 709dd03..4e95ff6 100644
--- a/src/sem/function.cc
+++ b/src/sem/function.cc
@@ -30,11 +30,11 @@
 
 Function::Function(const ast::Function* declaration,
                    Type* return_type,
-                   std::vector<Parameter*> parameters,
-                   sem::WorkgroupSize workgroup_size)
+                   std::vector<Parameter*> parameters)
     : Base(return_type, utils::ToConstPtrVec(parameters)),
       declaration_(declaration),
-      workgroup_size_(std::move(workgroup_size)) {
+      workgroup_size_{WorkgroupDimension{1}, WorkgroupDimension{1},
+                      WorkgroupDimension{1}} {
   for (auto* parameter : parameters) {
     parameter->SetOwner(this);
   }
diff --git a/src/sem/function.h b/src/sem/function.h
index d8a58a8..ea834a7 100644
--- a/src/sem/function.h
+++ b/src/sem/function.h
@@ -62,11 +62,9 @@
   /// @param declaration the ast::Function
   /// @param return_type the return type of the function
   /// @param parameters the parameters to the function
-  /// @param workgroup_size the workgroup size
   Function(const ast::Function* declaration,
            Type* return_type,
-           std::vector<Parameter*> parameters,
-           sem::WorkgroupSize workgroup_size);
+           std::vector<Parameter*> parameters);
 
   /// Destructor
   ~Function() override;
@@ -77,6 +75,12 @@
   /// @returns the workgroup size {x, y, z} for the function.
   const sem::WorkgroupSize& WorkgroupSize() const { return workgroup_size_; }
 
+  /// Sets the workgroup size {x, y, z} for the function.
+  /// @param workgroup_size the new workgroup size of the function
+  void SetWorkgroupSize(sem::WorkgroupSize workgroup_size) {
+    workgroup_size_ = std::move(workgroup_size);
+  }
+
   /// @returns all directly referenced global variables
   const utils::UniqueVector<const GlobalVariable*>& DirectlyReferencedGlobals()
       const {
@@ -243,8 +247,8 @@
       bool multisampled) const;
 
   const ast::Function* const declaration_;
-  const sem::WorkgroupSize workgroup_size_;
 
+  sem::WorkgroupSize workgroup_size_;
   utils::UniqueVector<const GlobalVariable*> directly_referenced_globals_;
   utils::UniqueVector<const GlobalVariable*> transitively_referenced_globals_;
   utils::UniqueVector<const Function*> transitively_called_functions_;
diff --git a/src/transform/single_entry_point_test.cc b/src/transform/single_entry_point_test.cc
index cdab094..b2193b1 100644
--- a/src/transform/single_entry_point_test.cc
+++ b/src/transform/single_entry_point_test.cc
@@ -236,6 +236,26 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(SingleEntryPointTest, WorkgroupSizeLetPreserved) {
+  auto* src = R"(
+let size : i32 = 1;
+
+[[stage(compute), workgroup_size(size)]]
+fn main() {
+}
+)";
+
+  auto* expect = src;
+
+  SingleEntryPoint::Config cfg("main");
+
+  DataMap data;
+  data.Add<SingleEntryPoint::Config>(cfg);
+  auto got = Run<SingleEntryPoint>(src, data);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(SingleEntryPointTest, OverridableConstants) {
   auto* src = R"(
 [[override(1001)]] let c1 : u32 = 1u;