[tint] Validate subgroup matrix declarations
They are allowed in:
- let
- var<function> and var<private>
- function parameters
- function return types
- arrays
- structures
- ptr<function> and ptr<private>
They are not allowed in:
- const
- override
- any other var address space
They are considered fixed-footprint in order to allow them in arrays
and structures.
Bug: 348702031
Change-Id: If53a91eb38da5a2ef295e10019dbafdac89d0544
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/227175
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: James Price <jrprice@google.com>
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 00da12f..dced5b1 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -5045,6 +5045,16 @@
const_cast<core::type::Type*>(arr->ElemType()), usage);
}
+ // Subgroup matrix types can only be declared in the `function` and `private` address space, or
+ // in value declarations (the `undefined` address space).
+ if (ty->Is<core::type::SubgroupMatrix>() && address_space != core::AddressSpace::kUndefined &&
+ address_space != core::AddressSpace::kFunction &&
+ address_space != core::AddressSpace::kPrivate) {
+ AddError(usage) << "subgroup matrix types cannot be declared in the "
+ << style::Enum(address_space) << " address space";
+ return false;
+ }
+
if (core::IsHostShareable(address_space) && !validator_.IsHostShareable(ty)) {
AddError(usage) << "type " << style::Type(sem_.TypeNameOf(ty))
<< " cannot be used in address space " << style::Enum(address_space)
diff --git a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
index dc6c5fe..b8410e6c 100644
--- a/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
+++ b/src/tint/lang/wgsl/resolver/subgroup_matrix_test.cc
@@ -463,5 +463,182 @@
testing::HasSubstr(R"(error: no matching call to 'subgroupMatrixMultiply)"));
}
+TEST_F(ResolverSubgroupMatrixTest, Let_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Decl(Let("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a),
+ Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a)))),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionVar_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Func("foo", Empty, ty.void_(),
+ Vector{
+ Decl(Var("result", function, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, PrivateVar_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalVar("result", private_, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, StorageVar_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalVar("result", storage, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), Group(0_a),
+ Binding(0_a));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup matrix types cannot be declared in the 'storage' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, UniformVar_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalVar("result", uniform, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), Group(0_a),
+ Binding(0_a));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup matrix types cannot be declared in the 'uniform' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, WorkgroupVar_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalVar("result", workgroup, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, PrivateVar_ArrayElement_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalVar("result", private_, ty.array(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 8_a));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, WorkgroupVar_ArrayElement_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalVar("result", workgroup, ty.array(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a), 8_a));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, PrivateVar_StructMember_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+
+ auto* s = Structure("S", Vector{
+ Member("m", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a)),
+ });
+ GlobalVar("result", private_, ty.Of(s));
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, WorkgroupVar_StructMember_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+
+ auto* s = Structure("S", Vector{
+ Member("m", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a)),
+ });
+ GlobalVar("result", workgroup, ty.Of(s));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ConstVar_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ GlobalConst("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a),
+ Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: const initializer requires a const-expression, but expression is a runtime-expression)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, OverrideVar_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Override("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup_matrix_result<f32, 8, 8> cannot be used as the type of a 'override')"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionParameter_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Func("foo",
+ Vector{
+ Param("result", ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a)),
+ },
+ ty.void_(), Empty);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionParameter_FunctionPointer_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Func("foo",
+ Vector{
+ Param("result", ty.ptr<function>(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+ },
+ ty.void_(), Empty);
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverSubgroupMatrixTest, FunctionParameter_WorkgroupPointer_Invalid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Func("foo",
+ Vector{
+ Param("result", ty.ptr<workgroup>(ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+ },
+ ty.void_(), Empty);
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_THAT(
+ r()->error(),
+ testing::HasSubstr(
+ R"(error: subgroup matrix types cannot be declared in the 'workgroup' address space)"));
+}
+
+TEST_F(ResolverSubgroupMatrixTest, ReturnType_Valid) {
+ Enable(wgsl::Extension::kChromiumExperimentalSubgroupMatrix);
+ Func("foo", Empty, ty("subgroup_matrix_result", ty.f32(), 8_a, 8_a),
+ Vector{
+ Return(Call(Ident("subgroup_matrix_result", ty.f32(), 8_a, 8_a))),
+ });
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
} // namespace
} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index bc900a9..db1350c 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -224,6 +224,7 @@
[&](const core::type::Vector*) { return true; }, //
[&](const core::type::Matrix*) { return true; }, //
[&](const core::type::Atomic*) { return true; },
+ [&](const core::type::SubgroupMatrix*) { return true; },
[&](const sem::Array* arr) {
return !arr->Count()->Is<core::type::RuntimeArrayCount>() &&
IsFixedFootprint(arr->ElemType());