[spirv-reader]: Fix replacement of wgsize with multiple entry points
When multiple entry points use workgroup size, then
take care to replace each use of the global variable
only once.
Change-Id: I2a1d31fe4bc6805e37a444eb7f98b0dc211e1fdd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/242794
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
index 63308cb..0579c8b 100644
--- a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
@@ -55,6 +55,9 @@
/// Map from function to the name of its workgroup_id parameter.
Hashmap<const ast::Function*, Symbol, 8> func_to_param;
+ /// The identifier expressions that have been replaced, other than
+ /// the assignments in entry points.
+ Hashset<const ast::Expression*, 16> replaced_idexpr;
/// Constructor
/// @param program the source program
@@ -63,21 +66,38 @@
/// Runs the transform.
/// @returns the new program
ApplyResult Run() {
+ struct Root {
+ const ast::Function* entry_point;
+ const ast::Parameter* param;
+ };
+
// Process all entry points in the module, looking for workgroup_id builtin parameters.
bool made_changes = false;
+ Vector<Root, 4> roots;
for (auto* func : src.AST().Functions()) {
if (func->IsEntryPoint()) {
for (auto* param : func->params) {
if (auto* builtin =
ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
if (builtin->builtin == core::BuiltinValue::kWorkgroupId) {
- ProcessBuiltin(func, param);
- made_changes = true;
+ roots.Emplace(func, param);
}
}
}
}
}
+ // Record the name of the parameter for all entry point functions
+ // that use the workgroup size builtin. Do this before traversing
+ // call sites of uses of the global wgsize variable.
+ for (const auto& root : roots) {
+ // Record the name of the parameter for the entry point function.
+ func_to_param.Add(root.entry_point, ctx.Clone(root.param->name->symbol));
+ made_changes = true;
+ }
+ for (const auto& root : roots) {
+ ProcessBuiltin(root.entry_point, root.param);
+ }
+
if (!made_changes) {
return SkipTransform;
}
@@ -90,9 +110,6 @@
/// @param ep the entry point function
/// @param builtin the builtin parameter
void ProcessBuiltin(const ast::Function* ep, const ast::Parameter* builtin) {
- // Record the name of the parameter for the entry point function.
- func_to_param.Add(ep, ctx.Clone(builtin->name->symbol));
-
// The reader should only produce a single use of the parameter which assigns to a global.
const auto& users = sem.Get(builtin)->Users();
TINT_ASSERT(users.Length() == 1u);
@@ -122,9 +139,20 @@
// Skip the assignment, which will be removed.
continue;
}
- auto param = GetParameter(user->Stmt()->Function()->Declaration(),
- lhs->Variable()->Declaration()->type);
- ctx.Replace(user->Declaration(), b.Expr(param));
+ auto* using_func = user->Stmt()->Function()->Declaration();
+ if (using_func->IsEntryPoint() && using_func != ep) {
+ // The other entry point will update its own uses.
+ continue;
+ }
+
+ // This use might be reachable from more than one entry point.
+ // Only replace it once. If we replaced it every time, then
+ // all but the last replacement identifier expressions would
+ // be orphaned.
+ if (replaced_idexpr.Add(user->Declaration())) {
+ auto param = GetParameter(using_func, lhs->Variable()->Declaration()->type);
+ ctx.Replace(user->Declaration(), b.Expr(param));
+ }
}
// Remove the global variable and the assignment to it.
@@ -144,7 +172,8 @@
// Recursively update all callsites to pass the workgroup_id as an argument.
for (auto* callsite : sem.Get(func)->CallSites()) {
- auto param = GetParameter(callsite->Stmt()->Function()->Declaration(), type);
+ auto caller_func = callsite->Stmt()->Function()->Declaration();
+ auto param = GetParameter(caller_func, type);
ctx.InsertBack(callsite->Declaration()->args, b.Expr(param));
}
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
index 3c1bbcd..daaaf2d 100644
--- a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
@@ -404,5 +404,173 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(PassWorkgroupIdAsArgumentTest, TwoEntryPoints_UsedInBoth) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+ if (wgid.x == 0u) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+
+fn other_inner() {
+ if (wgid.x == 2u) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ other_inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0u)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+
+fn other_inner(tint_wgid_1 : vec3u) {
+ if ((tint_wgid_1.x == 2u)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+ other_inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, TwoEntryPoints_UsedInCommon) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+ if (wgid.x == 0u) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0u)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, TwoEntryPoints_UsedInOnlyOne) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+ if (wgid.x == 0u) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+
+fn other_inner() {
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ other_inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0u)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+
+fn other_inner() {
+}
+
+@compute @workgroup_size(64)
+fn other(@builtin(workgroup_id) wgid_param : vec3u) {
+ other_inner();
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::spirv::reader