[msl][ir] Emit `subgroup_matrix` types and zero values
This CL adds the emission of `subgroup_matrix_left`,
`subgroup_matrix_right` and `subgroup_matrix_result` types from the MSL
IR backend. The zero values are also emitted.
Bug: 348702031
Change-Id: Ia325ff881c25330eaecc2d68b4f79777a55b7041
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202974
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 7595cf2..7b51791 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -1196,6 +1196,15 @@
}
}, //
[&](const msl::type::Level*) { out << "level"; }, //
+ [&](const core::type::SubgroupMatrix* sm) {
+ TINT_ASSERT((sm->Type()->IsAnyOf<core::type::F32, core::type::F16>()));
+ TINT_ASSERT(sm->Columns() == 8);
+ TINT_ASSERT(sm->Rows() == 8);
+
+ out << "simdgroup_";
+ EmitType(out, sm->Type());
+ out << sm->Columns() << "x" << sm->Rows();
+ },
TINT_ICE_ON_NO_MATCH);
}
@@ -1619,6 +1628,13 @@
},
[&](const core::type::Array*) { out << "{}"; }, //
[&](const core::type::Struct*) { out << "{}"; }, //
+ [&](const core::type::SubgroupMatrix* sm) {
+ out << "make_filled_simdgroup_matrix<";
+ EmitType(out, sm->Type());
+ out << ", " << sm->Columns() << ", " << sm->Rows() << ">(";
+ EmitZeroValue(out, sm->Type());
+ out << ")";
+ },
TINT_ICE_ON_NO_MATCH);
}
diff --git a/src/tint/lang/msl/writer/type_test.cc b/src/tint/lang/msl/writer/type_test.cc
index 7efe302..d58b464 100644
--- a/src/tint/lang/msl/writer/type_test.cc
+++ b/src/tint/lang/msl/writer/type_test.cc
@@ -1104,5 +1104,53 @@
MslStorageTextureData{core::type::TextureDimension::k3d,
"texture3d<float, access::write>"}));
+// Metal only supports f{16, 32} at (8x8). Bfloat is also supported but isn't in WGSL.
+TEST_F(MslWriterTest, EmitType_SubgroupMatrixLeft) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr(core::AddressSpace::kPrivate, ty.subgroup_matrix_left(ty.f32(), 8, 8)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate({}, validate::MslVersion::kMsl_2_3)) << err_ << output_.msl;
+ EXPECT_EQ(output_.msl, MetalHeader() + R"(
+void foo() {
+ thread simdgroup_float8x8 a = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
+}
+)");
+}
+
+// Metal only supports f{16, 32} at (8x8). Bfloat is also supported but isn't in WGSL.
+TEST_F(MslWriterTest, EmitType_SubgroupMatrixRight) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr(core::AddressSpace::kFunction, ty.subgroup_matrix_right(ty.f16(), 8, 8)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate({}, validate::MslVersion::kMsl_2_3)) << err_ << output_.msl;
+ EXPECT_EQ(output_.msl, MetalHeader() + R"(
+void foo() {
+ simdgroup_half8x8 a = make_filled_simdgroup_matrix<half, 8, 8>(0.0h);
+}
+)");
+}
+
+// Metal only supports f{16, 32} at (8x8). Bfloat is also supported but isn't in WGSL.
+TEST_F(MslWriterTest, EmitType_SubgroupMatrixResult) {
+ auto* func = b.Function("foo", ty.void_());
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr(core::AddressSpace::kPrivate, ty.subgroup_matrix_result(ty.f32(), 8, 8)));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate({}, validate::MslVersion::kMsl_2_3)) << err_ << output_.msl;
+ EXPECT_EQ(output_.msl, MetalHeader() + R"(
+void foo() {
+ thread simdgroup_float8x8 a = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
+}
+)");
+}
+
} // namespace
} // namespace tint::msl::writer