[spirv-reader] Pass workgroup_id as argument
To avoid uniformity issues, run a transform to pass workgroup_id
builtin values down the callstack instead of storing and re-loading
them from a module-scope private variable.
This requires that we temporarily disable the uniformity analysis in
the SPIR-V reader until this transform has run.
Fixed: tint:2031
Change-Id: If6d4ca6ab8b6561828fffe3ff87423c779faafed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/164940
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel b/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
index 4bb7e6a..0fdbcd1 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.bazel
@@ -43,12 +43,14 @@
"decompose_strided_array.cc",
"decompose_strided_matrix.cc",
"fold_trivial_lets.cc",
+ "pass_workgroup_id_as_argument.cc",
],
hdrs = [
"atomics.h",
"decompose_strided_array.h",
"decompose_strided_matrix.h",
"fold_trivial_lets.h",
+ "pass_workgroup_id_as_argument.h",
],
deps = [
"//src/tint/api/common",
@@ -88,6 +90,7 @@
"decompose_strided_array_test.cc",
"decompose_strided_matrix_test.cc",
"fold_trivial_lets_test.cc",
+ "pass_workgroup_id_as_argument_test.cc",
],
deps = [
"//src/tint/api/common",
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake b/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
index 6d3870b..1f1a04c 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.cmake
@@ -49,6 +49,8 @@
lang/spirv/reader/ast_lower/decompose_strided_matrix.h
lang/spirv/reader/ast_lower/fold_trivial_lets.cc
lang/spirv/reader/ast_lower/fold_trivial_lets.h
+ lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
+ lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h
)
tint_target_add_dependencies(tint_lang_spirv_reader_ast_lower lib
@@ -91,6 +93,7 @@
lang/spirv/reader/ast_lower/decompose_strided_array_test.cc
lang/spirv/reader/ast_lower/decompose_strided_matrix_test.cc
lang/spirv/reader/ast_lower/fold_trivial_lets_test.cc
+ lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
)
tint_target_add_dependencies(tint_lang_spirv_reader_ast_lower_test test
diff --git a/src/tint/lang/spirv/reader/ast_lower/BUILD.gn b/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
index c2ded34..9381efc 100644
--- a/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
+++ b/src/tint/lang/spirv/reader/ast_lower/BUILD.gn
@@ -52,6 +52,8 @@
"decompose_strided_matrix.h",
"fold_trivial_lets.cc",
"fold_trivial_lets.h",
+ "pass_workgroup_id_as_argument.cc",
+ "pass_workgroup_id_as_argument.h",
]
deps = [
"${tint_src_dir}/api/common",
@@ -91,6 +93,7 @@
"decompose_strided_array_test.cc",
"decompose_strided_matrix_test.cc",
"fold_trivial_lets_test.cc",
+ "pass_workgroup_id_as_argument_test.cc",
]
deps = [
"${tint_src_dir}:gmock_and_gtest",
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
new file mode 100644
index 0000000..cfaf300
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.cc
@@ -0,0 +1,163 @@
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h"
+
+#include <utility>
+
+#include "src/tint/lang/wgsl/program/clone_context.h"
+#include "src/tint/lang/wgsl/program/program_builder.h"
+#include "src/tint/lang/wgsl/resolver/resolve.h"
+#include "src/tint/lang/wgsl/sem/function.h"
+#include "src/tint/lang/wgsl/sem/statement.h"
+#include "src/tint/utils/containers/hashmap.h"
+
+using namespace tint::core::fluent_types; // NOLINT
+
+TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::PassWorkgroupIdAsArgument);
+
+namespace tint::spirv::reader {
+
+/// PIMPL state for the transform.
+struct PassWorkgroupIdAsArgument::State {
+ /// The source program
+ const Program& src;
+ /// The target program builder
+ ProgramBuilder b;
+ /// The clone context
+ program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};
+ /// The semantic info.
+ const sem::Info& sem = src.Sem();
+
+ /// Map from function to the name of its workgroup_id parameter.
+ Hashmap<const ast::Function*, Symbol, 8> func_to_param;
+
+ /// Constructor
+ /// @param program the source program
+ explicit State(const Program& program) : src(program) {}
+
+ /// Runs the transform.
+ /// @returns the new program
+ ApplyResult Run() {
+ // Process all entry points in the module, looking for workgroup_id builtin parameters.
+ bool made_changes = false;
+ for (auto* func : src.AST().Functions()) {
+ if (func->IsEntryPoint()) {
+ for (auto* param : func->params) {
+ if (auto* builtin =
+ ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
+ if (sem.Get(builtin)->Value() == core::BuiltinValue::kWorkgroupId) {
+ ProcessBuiltin(func, param);
+ made_changes = true;
+ }
+ }
+ }
+ }
+ }
+ if (!made_changes) {
+ return SkipTransform;
+ }
+
+ ctx.Clone();
+ return resolver::Resolve(b);
+ }
+
+ /// Process a workgroup_id builtin.
+ /// @param ep the entry point function
+ /// @param builtin the builtin parameter
+ void ProcessBuiltin(const ast::Function* ep, const ast::Parameter* builtin) {
+ // Record the name of the parameter for the entry point function.
+ func_to_param.Add(ep, ctx.Clone(builtin->name->symbol));
+
+ // The reader should only produce a single use of the parameter which assigns to a global.
+ const auto& users = sem.Get(builtin)->Users();
+ TINT_ASSERT_OR_RETURN(users.Length() == 1u);
+ auto* assign = users[0]->Stmt()->Declaration()->As<ast::AssignmentStatement>();
+ auto& stmts =
+ sem.Get(assign)->Parent()->Declaration()->As<ast::BlockStatement>()->statements;
+ auto* rhs = assign->rhs;
+ if (auto* bitcast = rhs->As<ast::BitcastExpression>()) {
+ // The RHS may be bitcast to a signed integer, so we capture that bitcast.
+ auto let = b.Symbols().New("tint_wgid_bitcast");
+ ctx.InsertBefore(stmts, assign, b.Decl(b.Let(let, ctx.Clone(bitcast))));
+ func_to_param.Replace(ep, let);
+ rhs = bitcast->expr;
+ }
+ TINT_ASSERT_OR_RETURN(assign && rhs == users[0]->Declaration());
+ auto* lhs = sem.GetVal(assign->lhs)->As<sem::VariableUser>();
+ TINT_ASSERT_OR_RETURN(lhs &&
+ lhs->Variable()->AddressSpace() == core::AddressSpace::kPrivate);
+
+ // Replace all references to the global variable with a function parameter.
+ for (auto* user : lhs->Variable()->Users()) {
+ if (user == lhs) {
+ // Skip the assignment, which will be removed.
+ continue;
+ }
+ auto param = GetParameter(user->Stmt()->Function()->Declaration(),
+ lhs->Variable()->Declaration()->type);
+ ctx.Replace(user->Declaration(), b.Expr(param));
+ }
+
+ // Remove the global variable and the assignment to it.
+ ctx.Remove(src.AST().GlobalDeclarations(), lhs->Variable()->Declaration());
+ ctx.Remove(stmts, assign);
+ }
+
+ /// Get the workgroup_id parameter for a function, creating it and updating callsites if needed.
+ /// @param func the function
+ /// @param type the type of the parameter
+ /// @returns the name of the parameter
+ Symbol GetParameter(const ast::Function* func, const ast::Type& type) {
+ return func_to_param.GetOrCreate(func, [&] {
+ // Append a new parameter to the function.
+ auto name = b.Symbols().New("tint_wgid");
+ ctx.InsertBack(func->params, b.Param(name, ctx.Clone(type)));
+
+ // Recursively update all callsites to pass the workgroup_id as an argument.
+ for (auto* callsite : sem.Get(func)->CallSites()) {
+ auto param = GetParameter(callsite->Stmt()->Function()->Declaration(), type);
+ ctx.InsertBack(callsite->Declaration()->args, b.Expr(param));
+ }
+
+ return name;
+ });
+ }
+};
+
+PassWorkgroupIdAsArgument::PassWorkgroupIdAsArgument() = default;
+
+PassWorkgroupIdAsArgument::~PassWorkgroupIdAsArgument() = default;
+
+ast::transform::Transform::ApplyResult PassWorkgroupIdAsArgument::Apply(
+ const Program& src,
+ const ast::transform::DataMap&,
+ ast::transform::DataMap&) const {
+ return State(src).Run();
+}
+
+} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h
new file mode 100644
index 0000000..e680352
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h
@@ -0,0 +1,59 @@
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_SPIRV_READER_AST_LOWER_PASS_WORKGROUP_ID_AS_ARGUMENT_H_
+#define SRC_TINT_LANG_SPIRV_READER_AST_LOWER_PASS_WORKGROUP_ID_AS_ARGUMENT_H_
+
+#include "src/tint/lang/wgsl/ast/transform/transform.h"
+
+namespace tint::spirv::reader {
+
+/// PassWorkgroupIdAsArgument is a transform that passes the workgroup_id builtin as an argument to
+/// functions that need it, instead of using a module-scope private variable. This allows the
+/// uniformity analysis to see that it is uniform, enabling shaders that use barriers in control
+/// flow guarded by this builtin.
+class PassWorkgroupIdAsArgument final
+ : public Castable<PassWorkgroupIdAsArgument, ast::transform::Transform> {
+ public:
+ /// Constructor
+ PassWorkgroupIdAsArgument();
+
+ /// Destructor
+ ~PassWorkgroupIdAsArgument() override;
+
+ /// @copydoc ast::transform::Transform::Apply
+ ApplyResult Apply(const Program& program,
+ const ast::transform::DataMap& inputs,
+ ast::transform::DataMap& outputs) const override;
+
+ private:
+ struct State;
+};
+
+} // namespace tint::spirv::reader
+
+#endif // SRC_TINT_LANG_SPIRV_READER_AST_LOWER_PASS_WORKGROUP_ID_AS_ARGUMENT_H_
diff --git a/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
new file mode 100644
index 0000000..3c1bbcd
--- /dev/null
+++ b/src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument_test.cc
@@ -0,0 +1,408 @@
+// Copyright 2023 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h"
+
+#include "src/tint/lang/wgsl/ast/transform/helper_test.h"
+
+namespace tint::spirv::reader {
+namespace {
+
+using PassWorkgroupIdAsArgumentTest = ast::transform::TransformTest;
+
+TEST_F(PassWorkgroupIdAsArgumentTest, Basic) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+ if (wgid.x == 0) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, MultipleUses) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner() {
+ if (wgid.x == 0) {
+ workgroupBarrier();
+ }
+ if (wgid.y == 0) {
+ workgroupBarrier();
+ }
+ if (wgid.z == 0) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0)) {
+ workgroupBarrier();
+ }
+ if ((tint_wgid.y == 0)) {
+ workgroupBarrier();
+ }
+ if ((tint_wgid.z == 0)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, NestedCall) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner_2() {
+ if (wgid.x == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1() {
+ inner_2();
+}
+
+fn inner() {
+ inner_1();
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner_2(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0)) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1(tint_wgid_1 : vec3u) {
+ inner_2(tint_wgid_1);
+}
+
+fn inner(tint_wgid_2 : vec3u) {
+ inner_1(tint_wgid_2);
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, NestedCall_UsesAtEachLevel) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner_2() {
+ if (wgid.x == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1() {
+ inner_2();
+ if (wgid.y == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn inner() {
+ inner_1();
+ if (wgid.z == 0) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner_2(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0)) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1(tint_wgid_1 : vec3u) {
+ inner_2(tint_wgid_1);
+ if ((tint_wgid_1.y == 0)) {
+ workgroupBarrier();
+ }
+}
+
+fn inner(tint_wgid_2 : vec3u) {
+ inner_1(tint_wgid_2);
+ if ((tint_wgid_2.z == 0)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, NestedCall_MultipleCallsites) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner_2() {
+ if (wgid.x == 0) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1() {
+ inner_2();
+ inner_2();
+ inner_2();
+}
+
+fn inner() {
+ inner_1();
+ inner_2();
+ inner_1();
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner_2(tint_wgid : vec3u) {
+ if ((tint_wgid.x == 0)) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1(tint_wgid_1 : vec3u) {
+ inner_2(tint_wgid_1);
+ inner_2(tint_wgid_1);
+ inner_2(tint_wgid_1);
+}
+
+fn inner(tint_wgid_2 : vec3u) {
+ inner_1(tint_wgid_2);
+ inner_2(tint_wgid_2);
+ inner_1(tint_wgid_2);
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, NestedCall_OtherParameters) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3u;
+
+fn inner_2(a : u32, b : u32) {
+ if (wgid.x + a == b) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1(a : u32) {
+ inner_2(a, 1);
+}
+
+fn inner() {
+ inner_1(2);
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = wgid_param;
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner_2(a : u32, b : u32, tint_wgid : vec3u) {
+ if (((tint_wgid.x + a) == b)) {
+ workgroupBarrier();
+ }
+}
+
+fn inner_1(a : u32, tint_wgid_1 : vec3u) {
+ inner_2(a, 1, tint_wgid_1);
+}
+
+fn inner(tint_wgid_2 : vec3u) {
+ inner_1(2, tint_wgid_2);
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ inner(wgid_param);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(PassWorkgroupIdAsArgumentTest, BitcastToI32) {
+ auto* src = R"(
+enable chromium_disable_uniformity_analysis;
+
+var<private> wgid : vec3i;
+
+fn inner() {
+ if (wgid.x == 0i) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ wgid = bitcast<vec3i>(wgid_param);
+ inner();
+}
+)";
+
+ auto* expect = R"(
+enable chromium_disable_uniformity_analysis;
+
+fn inner(tint_wgid : vec3i) {
+ if ((tint_wgid.x == 0i)) {
+ workgroupBarrier();
+ }
+}
+
+@compute @workgroup_size(64)
+fn main(@builtin(workgroup_id) wgid_param : vec3u) {
+ let tint_wgid_bitcast = bitcast<vec3i>(wgid_param);
+ inner(tint_wgid_bitcast);
+}
+)";
+
+ auto got = Run<PassWorkgroupIdAsArgument>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+} // namespace
+} // namespace tint::spirv::reader
diff --git a/src/tint/lang/spirv/reader/ast_parser/parse.cc b/src/tint/lang/spirv/reader/ast_parser/parse.cc
index e3f8cf6..c1b6af7 100644
--- a/src/tint/lang/spirv/reader/ast_parser/parse.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/parse.cc
@@ -33,16 +33,49 @@
#include "src/tint/lang/spirv/reader/ast_lower/decompose_strided_array.h"
#include "src/tint/lang/spirv/reader/ast_lower/decompose_strided_matrix.h"
#include "src/tint/lang/spirv/reader/ast_lower/fold_trivial_lets.h"
+#include "src/tint/lang/spirv/reader/ast_lower/pass_workgroup_id_as_argument.h"
#include "src/tint/lang/spirv/reader/ast_parser/ast_parser.h"
#include "src/tint/lang/wgsl/ast/transform/manager.h"
#include "src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h"
#include "src/tint/lang/wgsl/ast/transform/simplify_pointers.h"
#include "src/tint/lang/wgsl/ast/transform/unshadow.h"
+#include "src/tint/lang/wgsl/extension.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
namespace tint::spirv::reader::ast_parser {
+namespace {
+
+/// Trivial transform that removes the enable directive that disables the uniformity analysis.
+class ReenableUniformityAnalysis final
+ : public Castable<ReenableUniformityAnalysis, ast::transform::Transform> {
+ public:
+ ReenableUniformityAnalysis() {}
+ ~ReenableUniformityAnalysis() override {}
+
+ /// @copydoc ast::transform::Transform::Apply
+ ApplyResult Apply(const Program& src,
+ const ast::transform::DataMap&,
+ ast::transform::DataMap&) const override {
+ ProgramBuilder b;
+ program::CloneContext ctx = {&b, &src, /* auto_clone_symbols */ true};
+
+ // Remove the extension that disables the uniformity analysis.
+ for (auto* enable : src.AST().Enables()) {
+ if (enable->HasExtension(wgsl::Extension::kChromiumDisableUniformityAnalysis) &&
+ enable->extensions.Length() == 1u) {
+ ctx.Remove(src.AST().GlobalDeclarations(), enable);
+ }
+ }
+
+ ctx.Clone();
+ return resolver::Resolve(b);
+ }
+};
+
+} // namespace
+
Program Parse(const std::vector<uint32_t>& input, const Options& options) {
ASTParser parser(input);
bool parsed = parser.Parse();
@@ -60,13 +93,19 @@
builder.DiagnosticDirective(wgsl::DiagnosticSeverity::kOff, "derivative_uniformity");
}
+ // Disable the uniformity analysis temporarily.
+ // We will run transforms that attempt to change the AST to satisfy the analysis.
+ auto allowed_features = options.allowed_features;
+ allowed_features.extensions.insert(wgsl::Extension::kChromiumDisableUniformityAnalysis);
+ builder.Enable(wgsl::Extension::kChromiumDisableUniformityAnalysis);
+
// The SPIR-V parser can construct disjoint AST nodes, which is invalid for
// the Resolver. Clone the Program to clean these up.
Program program_with_disjoint_ast(std::move(builder));
ProgramBuilder output;
program::CloneContext(&output, &program_with_disjoint_ast, false).Clone();
- auto program = Program(resolver::Resolve(output, options.allowed_features));
+ auto program = Program(resolver::Resolve(output, allowed_features));
if (!program.IsValid()) {
return program;
}
@@ -76,11 +115,15 @@
manager.Add<ast::transform::Unshadow>();
manager.Add<ast::transform::SimplifyPointers>();
manager.Add<FoldTrivialLets>();
+ manager.Add<PassWorkgroupIdAsArgument>();
manager.Add<DecomposeStridedMatrix>();
manager.Add<DecomposeStridedArray>();
manager.Add<ast::transform::RemoveUnreachableStatements>();
manager.Add<Atomics>();
+ manager.Add<ReenableUniformityAnalysis>();
return manager.Run(program, {}, outputs);
}
} // namespace tint::spirv::reader::ast_parser
+
+TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::ReenableUniformityAnalysis);
diff --git a/src/tint/lang/spirv/reader/ast_parser/parser_test.cc b/src/tint/lang/spirv/reader/ast_parser/parser_test.cc
index 7375186..404e925 100644
--- a/src/tint/lang/spirv/reader/ast_parser/parser_test.cc
+++ b/src/tint/lang/spirv/reader/ast_parser/parser_test.cc
@@ -91,6 +91,43 @@
EXPECT_EQ(program.Diagnostics().count(), 0u) << errs;
}
+TEST_F(ParserTest, WorkgroupIdGuardingBarrier) {
+ auto spv = test::Assemble(R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %foo "foo" %wgid
+ OpExecutionMode %foo LocalSize 1 1 1
+ OpDecorate %wgid BuiltIn WorkgroupId
+ %uint = OpTypeInt 32 0
+ %vec3u = OpTypeVector %uint 3
+%_ptr_Input_vec3u = OpTypePointer Input %vec3u
+ %uint_0 = OpConstant %uint 0
+ %uint_2 = OpConstant %uint 2
+ %uint_8 = OpConstant %uint 8
+ %wgid = OpVariable %_ptr_Input_vec3u Input
+ %void = OpTypeVoid
+ %bool = OpTypeBool
+ %func_type = OpTypeFunction %void
+ %foo = OpFunction %void None %func_type
+ %foo_start = OpLabel
+ %wgid_value = OpLoad %vec3u %wgid
+ %wgid_x = OpCompositeExtract %uint %wgid_value 0
+ %condition = OpIEqual %bool %wgid_x %uint_0
+ OpSelectionMerge %merge None
+ OpBranchConditional %condition %true_branch %merge
+%true_branch = OpLabel
+ OpControlBarrier %uint_2 %uint_2 %uint_8
+ OpBranch %merge
+ %merge = OpLabel
+ OpReturn
+ OpFunctionEnd
+)");
+ auto program = Parse(spv, {});
+ auto errs = program.Diagnostics().str();
+ EXPECT_TRUE(program.IsValid()) << errs;
+ EXPECT_EQ(program.Diagnostics().count(), 0u) << errs;
+}
+
// TODO(dneto): uint32 vec, valid SPIR-V
// TODO(dneto): uint32 vec, invalid SPIR-V