[tint] Register subgroup matrix loads/stores with alias analysis
Fixed: 408245013
Change-Id: Ie6da75d3d31738e4224e2f8df4fd90b5d4fd4e55
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/243097
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/wgsl/resolver/alias_analysis_test.cc b/src/tint/lang/wgsl/resolver/alias_analysis_test.cc
index 047d9e2..936a19f 100644
--- a/src/tint/lang/wgsl/resolver/alias_analysis_test.cc
+++ b/src/tint/lang/wgsl/resolver/alias_analysis_test.cc
@@ -1397,5 +1397,214 @@
::testing::ValuesIn(kWorkgroupUniformLoadActions),
::testing::Values(true, false)));
+////////////////////////////////////////////////////////////////////////////////
+// Subgroup matrix builtins
+////////////////////////////////////////////////////////////////////////////////
+enum class SubgroupMatrixAction {
+ kLoad,
+ kStore,
+};
+
+constexpr std::array kSubgroupMatrixActions{
+ SubgroupMatrixAction::kLoad,
+ SubgroupMatrixAction::kStore,
+};
+
+class SubgroupMatrixTest : public ResolverTestWithParam<std::tuple<SubgroupMatrixAction,
+ SubgroupMatrixAction,
+ core::AddressSpace,
+ bool /* aliased */>> {
+ protected:
+ static constexpr std::string_view kPass = "<PASS>";
+
+ ast::Type Ptr() {
+ auto address_space = std::get<2>(GetParam());
+ if (address_space == storage) {
+ return ty.ptr<storage, array<f32, 1024>, read_write>();
+ } else {
+ return ty.ptr<array<f32, 1024>>(address_space);
+ }
+ }
+
+ void SetUp() override {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ auto address_space = std::get<2>(GetParam());
+ if (address_space == storage) {
+ GlobalVar("v1", address_space, read_write, ty.array<f32, 1024>(), //
+ Binding(0_a), Group(0_a));
+ GlobalVar("v2", address_space, read_write, ty.array<f32, 1024>(), //
+ Binding(1_a), Group(0_a));
+ } else {
+ GlobalVar("v1", address_space, ty.array<f32, 1024>());
+ GlobalVar("v2", address_space, ty.array<f32, 1024>());
+ }
+ }
+
+ const ast::Statement* Do(SubgroupMatrixAction action, std::string_view ptr) {
+ switch (action) {
+ case SubgroupMatrixAction::kLoad:
+ return Assign(Phony(),
+ Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixLoad,
+ ty.subgroup_matrix(core::SubgroupMatrixKind::kResult,
+ ty.f32(), 8, 8)),
+ ptr, 0_u, false, 8_u));
+ case SubgroupMatrixAction::kStore:
+ return CallStmt(Call(
+ wgsl::BuiltinFn::kSubgroupMatrixStore, ptr, 0_u,
+ Call(ty.subgroup_matrix(core::SubgroupMatrixKind::kResult, ty.f32(), 8, 8)),
+ false, 8_u));
+ }
+ return nullptr;
+ }
+
+ bool IsWrite(SubgroupMatrixAction action) const {
+ return action == SubgroupMatrixAction::kStore;
+ }
+
+ bool ShouldPass() const {
+ auto [action_a, action_b, addrspace, aliased] = GetParam();
+ bool fail = aliased && (IsWrite(action_a) || IsWrite(action_b));
+ return !fail;
+ }
+
+ std::string Run() {
+ if (r()->Resolve()) {
+ return std::string(kPass);
+ }
+ return r()->error();
+ }
+};
+
+TEST_P(SubgroupMatrixTest, CallDirect) {
+ // var<ADDRSPACE> v1 : array<f32, 1024>;
+ // var<ADDRSPACE> v2 : array<f32, 1024>;
+ //
+ // fn caller() {
+ // callee(&v1, aliased ? &v1 : &v2);
+ // }
+ //
+ // fn callee(p1 : PTR, p2 : PTR) {
+ // <action-a>(p1);
+ // <action-b>(p2);
+ // }
+ auto [action_a, action_b, addrspace, aliased] = GetParam();
+
+ Func("caller", tint::Empty, ty.void_(),
+ Vector{
+ CallStmt(Call("callee", //
+ AddressOf(Source{{12, 34}}, "v1"),
+ AddressOf(Source{{56, 78}}, aliased ? "v1" : "v2"))),
+ });
+
+ Func("callee", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+ Vector{
+ Do(action_a, "p1"),
+ Do(action_b, "p2"),
+ });
+
+ EXPECT_EQ(Run(), ShouldPass() ? kPass : R"(56:78 error: invalid aliased pointer argument
+12:34 note: aliases with another argument passed here)");
+}
+
+TEST_P(SubgroupMatrixTest, CallThroughChain) {
+ // var<ADDRSPACE> v1 : array<f32, 1024>;
+ // var<ADDRSPACE> v2 : array<f32, 1024>;
+ //
+ // fn caller() {
+ // callee(&v1, aliased ? &v1 : &v2);
+ // }
+ //
+ // fn f2(p1 : PTR, p2 : PTR) {
+ // f1(p1, p2);
+ // }
+ //
+ // fn f1(p1 : PTR, p2 : PTR) {
+ // callee(p1, p2);
+ // }
+ //
+ // fn callee(p1 : PTR, p2 : PTR) {
+ // <action-a>(p1);
+ // <action-b>(p2);
+ // }
+ auto [action_a, action_b, addrspace, aliased] = GetParam();
+
+ Func("caller", tint::Empty, ty.void_(),
+ Vector{
+ CallStmt(Call("callee", //
+ AddressOf(Source{{12, 34}}, "v1"),
+ AddressOf(Source{{56, 78}}, aliased ? "v1" : "v2"))),
+ });
+
+ Func("f2", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+ Vector{
+ CallStmt(Call("f1", "p1", "p2")),
+ });
+
+ Func("f1", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+ Vector{
+ CallStmt(Call("callee", "p1", "p2")),
+ });
+
+ Func("callee", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+ Vector{
+ Do(action_a, "p1"),
+ Do(action_b, "p2"),
+ });
+
+ EXPECT_EQ(Run(), ShouldPass() ? kPass : R"(56:78 error: invalid aliased pointer argument
+12:34 note: aliases with another argument passed here)");
+}
+
+TEST_P(SubgroupMatrixTest, ReadWriteAcrossDifferentFunctions) {
+ // var<ADDRSPACE> v1 : array<f32, 1024>;
+ // var<ADDRSPACE> v2 : array<f32, 1024>;
+ //
+ // fn caller() {
+ // f(&v1, aliased ? &v1 : &v2);
+ // }
+ //
+ // fn f(p1 : PTR, p2 : PTR) {
+ // f1(p1);
+ // f2(p2);
+ // }
+ //
+ // fn f1(p : PTR) {
+ // <action-a>(p);
+ // }
+ //
+ // fn f2(p : PTR) {
+ // <action-b>(p);
+ // }
+ auto [action_a, action_b, addrspace, aliased] = GetParam();
+
+ Func("caller", tint::Empty, ty.void_(),
+ Vector{
+ CallStmt(Call("f", //
+ AddressOf(Source{{12, 34}}, "v1"),
+ AddressOf(Source{{56, 78}}, aliased ? "v1" : "v2"))),
+ });
+
+ Func("f", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+ Vector{
+ CallStmt(Call("f1", "p1")),
+ CallStmt(Call("f2", "p2")),
+ });
+
+ Func("f1", Vector{Param("p", Ptr())}, ty.void_(), Vector{Do(action_a, "p")});
+
+ Func("f2", Vector{Param("p", Ptr())}, ty.void_(), Vector{Do(action_b, "p")});
+
+ EXPECT_EQ(Run(), ShouldPass() ? kPass : R"(56:78 error: invalid aliased pointer argument
+12:34 note: aliases with another argument passed here)");
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverAliasAnalysisTest,
+ SubgroupMatrixTest,
+ ::testing::Combine(::testing::ValuesIn(kSubgroupMatrixActions),
+ ::testing::ValuesIn(kSubgroupMatrixActions),
+ ::testing::Values(core::AddressSpace::kWorkgroup,
+ core::AddressSpace::kStorage),
+ ::testing::Values(true, false)));
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 9776ad7..50b7475 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2450,6 +2450,13 @@
RegisterStore(args[0]);
break;
+ case wgsl::BuiltinFn::kSubgroupMatrixLoad:
+ RegisterLoad(args[0]);
+ break;
+ case wgsl::BuiltinFn::kSubgroupMatrixStore:
+ RegisterStore(args[0]);
+ break;
+
default:
break;
}