[spirv-reader]: Fix replacement of wgsize with multiple entry points

When multiple entry points use workgroup size, then
take care to replace each use of the global variable
only once.

Change-Id: I2a1d31fe4bc6805e37a444eb7f98b0dc211e1fdd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/242794
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
index 63308cb..0579c8b 100644
--- a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
@@ -55,6 +55,9 @@
 
     /// Map from function to the name of its workgroup_id parameter.
     Hashmap<const ast::Function*, Symbol, 8> func_to_param;
+    /// The identifier expressions that have been replaced, other than
+    /// the assignments in entry points.
+    Hashset<const ast::Expression*, 16> replaced_idexpr;
 
     /// Constructor
     /// @param program the source program
@@ -63,21 +66,38 @@
     /// Runs the transform.
     /// @returns the new program
     ApplyResult Run() {
+        struct Root {
+            const ast::Function* entry_point;
+            const ast::Parameter* param;
+        };
+
         // Process all entry points in the module, looking for workgroup_id builtin parameters.
         bool made_changes = false;
+        Vector<Root, 4> roots;
         for (auto* func : src.AST().Functions()) {
             if (func->IsEntryPoint()) {
                 for (auto* param : func->params) {
                     if (auto* builtin =
                             ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
                         if (builtin->builtin == core::BuiltinValue::kWorkgroupId) {
-                            ProcessBuiltin(func, param);
-                            made_changes = true;
+                            roots.Emplace(func, param);
                         }
                     }
                 }
             }
         }
+        // Record the name of the parameter for all entry point functions
+        // that use the workgroup size builtin.  Do this before traversing
+        // call sites of uses of the global wgsize variable.
+        for (const auto& root : roots) {
+            // Record the name of the parameter for the entry point function.
+            func_to_param.Add(root.entry_point, ctx.Clone(root.param->name->symbol));
+            made_changes = true;
+        }
+        for (const auto& root : roots) {
+            ProcessBuiltin(root.entry_point, root.param);
+        }
+
         if (!made_changes) {
             return SkipTransform;
         }
@@ -90,9 +110,6 @@
     /// @param ep the entry point function
     /// @param builtin the builtin parameter
     void ProcessBuiltin(const ast::Function* ep, const ast::Parameter* builtin) {
-        // Record the name of the parameter for the entry point function.
-        func_to_param.Add(ep, ctx.Clone(builtin->name->symbol));
-
         // The reader should only produce a single use of the parameter which assigns to a global.
         const auto& users = sem.Get(builtin)->Users();
         TINT_ASSERT(users.Length() == 1u);
@@ -122,9 +139,20 @@
                 // Skip the assignment, which will be removed.
                 continue;
             }
-            auto param = GetParameter(user->Stmt()->Function()->Declaration(),
-                                      lhs->Variable()->Declaration()->type);
-            ctx.Replace(user->Declaration(), b.Expr(param));
+            auto* using_func = user->Stmt()->Function()->Declaration();
+            if (using_func->IsEntryPoint() && using_func != ep) {
+                // The other entry point will update its own uses.
+                continue;
+            }
+
+            // This use might be reachable from more than one entry point.
+            // Only replace it once. If we replaced it every time, then
+            // all but the last replacement identifier expressions would
+            // be orphaned.
+            if (replaced_idexpr.Add(user->Declaration())) {
+                auto param = GetParameter(using_func, lhs->Variable()->Declaration()->type);
+                ctx.Replace(user->Declaration(), b.Expr(param));
+            }
         }
 
         // Remove the global variable and the assignment to it.
@@ -144,7 +172,8 @@
 
             // Recursively update all callsites to pass the workgroup_id as an argument.
             for (auto* callsite : sem.Get(func)->CallSites()) {
-                auto param = GetParameter(callsite->Stmt()->Function()->Declaration(), type);
+                auto caller_func = callsite->Stmt()->Function()->Declaration();
+                auto param = GetParameter(caller_func, type);
                 ctx.InsertBack(callsite->Declaration()->args, b.Expr(param));
             }
 
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
index 3c1bbcd..daaaf2d 100644
--- a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
@@ -404,5 +404,173 @@
     EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(PassWorkgroupIdAsArgumentTest, TwoEntryPoints_UsedInBoth) {
+    auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+  if (wgid.x == 0u) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+  wgid = wgid_param;
+  inner();
+}
+
+fn other_inner() {
+  if (wgid.x == 2u) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+  wgid = wgid_param;
+  other_inner();
+}
+)";
+
+    auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+  if ((tint_wgid.x == 0u)) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+  inner(wgid_param);
+}
+
+fn other_inner(tint_wgid_1 : vec3u) {
+  if ((tint_wgid_1.x == 2u)) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+  other_inner(wgid_param);
+}
+)";
+
+    auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, TwoEntryPoints_UsedInCommon) {
+    auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+  if (wgid.x == 0u) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+  wgid = wgid_param;
+  inner();
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+  wgid = wgid_param;
+  inner();
+}
+)";
+
+    auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+  if ((tint_wgid.x == 0u)) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+  inner(wgid_param);
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+  inner(wgid_param);
+}
+)";
+
+    auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, TwoEntryPoints_UsedInOnlyOne) {
+    auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+  if (wgid.x == 0u) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+  wgid = wgid_param;
+  inner();
+}
+
+fn other_inner() {
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+  wgid = wgid_param;
+  other_inner();
+}
+)";
+
+    auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+  if ((tint_wgid.x == 0u)) {
+    workgroupBarrier();
+  }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+  inner(wgid_param);
+}
+
+fn other_inner() {
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+  other_inner();
+}
+)";
+
+    auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+    EXPECT_EQ(expect, str(got));
+}
+
 }  // namespace
 }  // namespace tint::spirv::reader