[tint] Validate subgroup matrix declarations

They are allowed in:
- let
- var<function> and var<private>
- function parameters
- function return types
- arrays
- structures
- ptr<function> and ptr<private>

They are not allowed in:
- const
- override
- any other var address space

They are considered fixed-footprint in order to allow them in arrays
and structures.

Bug: 348702031
Change-Id: If53a91eb38da5a2ef295e10019dbafdac89d0544
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/227175
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 00da12f..dced5b1 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -5045,6 +5045,16 @@
                                             const_cast<core::type::Type*>(arr->ElemType()), usage);
     }
 
+    // Subgroup matrix types can only be declared in the `function` and `private` address space, or
+    // in value declarations (the `undefined` address space).
+    if (ty->Is<core::type::SubgroupMatrix>() && address_space != core::AddressSpace::kUndefined &&
+        address_space != core::AddressSpace::kFunction &&
+        address_space != core::AddressSpace::kPrivate) {
+        AddError(usage) << "subgroup matrix types cannot be declared in the "
+                        << style::Enum(address_space) << " address space";
+        return false;
+    }
+
     if (core::IsHostShareable(address_space) && !validator_.IsHostShareable(ty)) {
         AddError(usage) << "type " << style::Type(sem_.TypeNameOf(ty))
                         << " cannot be used in address space " << style::Enum(address_space)
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index dc6c5fe..b8410e6c 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -463,5 +463,182 @@
                 testing::HasSubstr(R"(error: no matching call to 'subgroupMatrixMultiply)"));
 }
 
+TEST_F(ResolverSubgroupMatrixTest, Let_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Decl(Let("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a),
+                      Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a)))),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionVar_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Func("foo", Empty, ty.void_(),
+         Vector{
+             Decl(Var("result", function, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, PrivateVar_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalVar("result", private_, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, StorageVar_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalVar("result", storage, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), Group(0_a),
+              Binding(0_a));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup matrix types cannot be declared in the 'storage' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, UniformVar_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalVar("result", uniform, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), Group(0_a),
+              Binding(0_a));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup matrix types cannot be declared in the 'uniform' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, WorkgroupVar_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalVar("result", workgroup, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, PrivateVar_ArrayElement_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalVar("result", private_, ty.array(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 8_a));
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, WorkgroupVar_ArrayElement_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalVar("result", workgroup, ty.array(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 8_a));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, PrivateVar_StructMember_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+
+    auto* s = Structure("S", Vector{
+                                 Member("m", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a)),
+                             });
+    GlobalVar("result", private_, ty.Of(s));
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, WorkgroupVar_StructMember_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+
+    auto* s = Structure("S", Vector{
+                                 Member("m", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a)),
+                             });
+    GlobalVar("result", workgroup, ty.Of(s));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ConstVar_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    GlobalConst("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a),
+                Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a)));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: const initializer requires a const-expression, but expression is a runtime-expression)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, OverrideVar_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Override("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup_matrix_result<f32, 8, 8> cannot be used as the type of a 'override')"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionParameter_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Func("foo",
+         Vector{
+             Param("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a)),
+         },
+         ty.void_(), Empty);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionParameter_FunctionPointer_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Func("foo",
+         Vector{
+             Param("result", ty.ptr<function>(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+         },
+         ty.void_(), Empty);
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionParameter_WorkgroupPointer_Invalid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Func("foo",
+         Vector{
+             Param("result", ty.ptr<workgroup>(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+         },
+         ty.void_(), Empty);
+
+    EXPECT_FALSE(r()->Resolve());
+    EXPECT_THAT(
+        r()->error(),
+        testing::HasSubstr(
+            R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ReturnType_Valid) {
+    Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+    Func("foo", Empty, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a),
+         Vector{
+             Return(Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+         });
+
+    EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
 }  // namespace
 }  // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index bc900a9..db1350c 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -224,6 +224,7 @@
         [&](const core::type::Vector*) { return true; },  //
         [&](const core::type::Matrix*) { return true; },  //
         [&](const core::type::Atomic*) { return true; },
+        [&](const core::type::SubgroupMatrix*) { return true; },
         [&](const sem::Array* arr) {
             return !arr->Count()->Is<core::type::RuntimeArrayCount>() &&
                    IsFixedFootprint(arr->ElemType());