[tint] Register subgroup matrix loads/stores with alias analysis

Fixed: 408245013
Change-Id: Ie6da75d3d31738e4224e2f8df4fd90b5d4fd4e55
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/243097
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/wgsl/resolver/alias_analysis_test.cc b/src/tint/lang/wgsl/resolver/alias_analysis_test.cc
index 047d9e2..936a19f 100644
--- a/src/tint/lang/wgsl/resolver/alias_analysis_test.cc
+++ b/src/tint/lang/wgsl/resolver/alias_analysis_test.cc
@@ -1397,5 +1397,214 @@
                                             ::testing::ValuesIn(kWorkgroupUniformLoadActions),
                                             ::testing::Values(true, false)));
 
+////////////////////////////////////////////////////////////////////////////////
+// Subgroup matrix builtins
+////////////////////////////////////////////////////////////////////////////////
+enum class SubgroupMatrixAction {
+    kLoad,
+    kStore,
+};
+
+constexpr std::array kSubgroupMatrixActions{
+    SubgroupMatrixAction::kLoad,
+    SubgroupMatrixAction::kStore,
+};
+
+class SubgroupMatrixTest : public ResolverTestWithParam<std::tuple<SubgroupMatrixAction,
+                                                                   SubgroupMatrixAction,
+                                                                   core::AddressSpace,
+                                                                   bool /* aliased */>> {
+  protected:
+    static constexpr std::string_view kPass = "<PASS>";
+
+    ast::Type Ptr() {
+        auto address_space = std::get<2>(GetParam());
+        if (address_space == storage) {
+            return ty.ptr<storage, array<f32, 1024>, read_write>();
+        } else {
+            return ty.ptr<array<f32, 1024>>(address_space);
+        }
+    }
+
+    void SetUp() override {
+        Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+        auto address_space = std::get<2>(GetParam());
+        if (address_space == storage) {
+            GlobalVar("v1", address_space, read_write, ty.array<f32, 1024>(),  //
+                      Binding(0_a), Group(0_a));
+            GlobalVar("v2", address_space, read_write, ty.array<f32, 1024>(),  //
+                      Binding(1_a), Group(0_a));
+        } else {
+            GlobalVar("v1", address_space, ty.array<f32, 1024>());
+            GlobalVar("v2", address_space, ty.array<f32, 1024>());
+        }
+    }
+
+    const ast::Statement* Do(SubgroupMatrixAction action, std::string_view ptr) {
+        switch (action) {
+            case SubgroupMatrixAction::kLoad:
+                return Assign(Phony(),
+                              Call(Ident(wgsl::BuiltinFn::kSubgroupMatrixLoad,
+                                         ty.subgroup_matrix(core::SubgroupMatrixKind::kResult,
+                                                            ty.f32(), 8, 8)),
+                                   ptr, 0_u, false, 8_u));
+            case SubgroupMatrixAction::kStore:
+                return CallStmt(Call(
+                    wgsl::BuiltinFn::kSubgroupMatrixStore, ptr, 0_u,
+                    Call(ty.subgroup_matrix(core::SubgroupMatrixKind::kResult, ty.f32(), 8, 8)),
+                    false, 8_u));
+        }
+        return nullptr;
+    }
+
+    bool IsWrite(SubgroupMatrixAction action) const {
+        return action == SubgroupMatrixAction::kStore;
+    }
+
+    bool ShouldPass() const {
+        auto [action_a, action_b, addrspace, aliased] = GetParam();
+        bool fail = aliased && (IsWrite(action_a) || IsWrite(action_b));
+        return !fail;
+    }
+
+    std::string Run() {
+        if (r()->Resolve()) {
+            return std::string(kPass);
+        }
+        return r()->error();
+    }
+};
+
+TEST_P(SubgroupMatrixTest, CallDirect) {
+    // var<ADDRSPACE> v1 : array<f32, 1024>;
+    // var<ADDRSPACE> v2 : array<f32, 1024>;
+    //
+    // fn caller() {
+    //    callee(&v1, aliased ? &v1 : &v2);
+    // }
+    //
+    // fn callee(p1 : PTR, p2 : PTR) {
+    //   <action-a>(p1);
+    //   <action-b>(p2);
+    // }
+    auto [action_a, action_b, addrspace, aliased] = GetParam();
+
+    Func("caller", tint::Empty, ty.void_(),
+         Vector{
+             CallStmt(Call("callee",  //
+                           AddressOf(Source{{12, 34}}, "v1"),
+                           AddressOf(Source{{56, 78}}, aliased ? "v1" : "v2"))),
+         });
+
+    Func("callee", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+         Vector{
+             Do(action_a, "p1"),
+             Do(action_b, "p2"),
+         });
+
+    EXPECT_EQ(Run(), ShouldPass() ? kPass : R"(56:78 error: invalid aliased pointer argument
+12:34 note: aliases with another argument passed here)");
+}
+
+TEST_P(SubgroupMatrixTest, CallThroughChain) {
+    // var<ADDRSPACE> v1 : array<f32, 1024>;
+    // var<ADDRSPACE> v2 : array<f32, 1024>;
+    //
+    // fn caller() {
+    //    callee(&v1, aliased ? &v1 : &v2);
+    // }
+    //
+    // fn f2(p1 : PTR, p2 : PTR) {
+    //    f1(p1, p2);
+    // }
+    //
+    // fn f1(p1 : PTR, p2 : PTR) {
+    //    callee(p1, p2);
+    // }
+    //
+    // fn callee(p1 : PTR, p2 : PTR) {
+    //   <action-a>(p1);
+    //   <action-b>(p2);
+    // }
+    auto [action_a, action_b, addrspace, aliased] = GetParam();
+
+    Func("caller", tint::Empty, ty.void_(),
+         Vector{
+             CallStmt(Call("callee",  //
+                           AddressOf(Source{{12, 34}}, "v1"),
+                           AddressOf(Source{{56, 78}}, aliased ? "v1" : "v2"))),
+         });
+
+    Func("f2", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+         Vector{
+             CallStmt(Call("f1", "p1", "p2")),
+         });
+
+    Func("f1", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+         Vector{
+             CallStmt(Call("callee", "p1", "p2")),
+         });
+
+    Func("callee", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+         Vector{
+             Do(action_a, "p1"),
+             Do(action_b, "p2"),
+         });
+
+    EXPECT_EQ(Run(), ShouldPass() ? kPass : R"(56:78 error: invalid aliased pointer argument
+12:34 note: aliases with another argument passed here)");
+}
+
+TEST_P(SubgroupMatrixTest, ReadWriteAcrossDifferentFunctions) {
+    // var<ADDRSPACE> v1 : array<f32, 1024>;
+    // var<ADDRSPACE> v2 : array<f32, 1024>;
+    //
+    // fn caller() {
+    //   f(&v1, aliased ? &v1 : &v2);
+    // }
+    //
+    // fn f(p1 : PTR, p2 : PTR) {
+    //    f1(p1);
+    //    f2(p2);
+    // }
+    //
+    // fn f1(p : PTR) {
+    //   <action-a>(p);
+    // }
+    //
+    // fn f2(p : PTR) {
+    //   <action-b>(p);
+    // }
+    auto [action_a, action_b, addrspace, aliased] = GetParam();
+
+    Func("caller", tint::Empty, ty.void_(),
+         Vector{
+             CallStmt(Call("f",  //
+                           AddressOf(Source{{12, 34}}, "v1"),
+                           AddressOf(Source{{56, 78}}, aliased ? "v1" : "v2"))),
+         });
+
+    Func("f", Vector{Param("p1", Ptr()), Param("p2", Ptr())}, ty.void_(),
+         Vector{
+             CallStmt(Call("f1", "p1")),
+             CallStmt(Call("f2", "p2")),
+         });
+
+    Func("f1", Vector{Param("p", Ptr())}, ty.void_(), Vector{Do(action_a, "p")});
+
+    Func("f2", Vector{Param("p", Ptr())}, ty.void_(), Vector{Do(action_b, "p")});
+
+    EXPECT_EQ(Run(), ShouldPass() ? kPass : R"(56:78 error: invalid aliased pointer argument
+12:34 note: aliases with another argument passed here)");
+}
+
+INSTANTIATE_TEST_SUITE_P(ResolverAliasAnalysisTest,
+                         SubgroupMatrixTest,
+                         ::testing::Combine(::testing::ValuesIn(kSubgroupMatrixActions),
+                                            ::testing::ValuesIn(kSubgroupMatrixActions),
+                                            ::testing::Values(core::AddressSpace::kWorkgroup,
+                                                              core::AddressSpace::kStorage),
+                                            ::testing::Values(true, false)));
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 9776ad7..50b7475 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -2450,6 +2450,13 @@
             RegisterStore(args[0]);
             break;
 
+        case wgsl::BuiltinFn::kSubgroupMatrixLoad:
+            RegisterLoad(args[0]);
+            break;
+        case wgsl::BuiltinFn::kSubgroupMatrixStore:
+            RegisterStore(args[0]);
+            break;
+
         default:
             break;
     }