[spirv] Emit subgroup_matrix type

Require the Vulkan Memory Model, and enable the CooperativeMatrixKHR
extension and capability.

Bug: 348702031
Change-Id: I91fadfad4aa4e9422d4bb4415284ceb466401772
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/224655
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 1d4bdb5..03135ff 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -89,6 +89,7 @@
 #include "src/tint/lang/core/type/sampler.h"
 #include "src/tint/lang/core/type/storage_texture.h"
 #include "src/tint/lang/core/type/struct.h"
+#include "src/tint/lang/core/type/subgroup_matrix.h"
 #include "src/tint/lang/core/type/texture.h"
 #include "src/tint/lang/core/type/type.h"
 #include "src/tint/lang/core/type/u32.h"
@@ -539,6 +540,37 @@
                 [&](const type::SampledImage* s) {
                     module_.PushType(spv::Op::OpTypeSampledImage, {id, Type(s->Image())});
                 },  //
+                [&](const core::type::SubgroupMatrix* sm) {
+                    TINT_ASSERT(options_.use_vulkan_memory_model);
+                    auto scope = Constant(ir_.constant_values.Get(u32(spv::Scope::Subgroup)));
+                    auto cols = Constant(ir_.constant_values.Get(u32(sm->Columns())));
+                    auto rows = Constant(ir_.constant_values.Get(u32(sm->Rows())));
+                    spv::CooperativeMatrixUse use = spv::CooperativeMatrixUse::Max;
+                    switch (sm->Kind()) {
+                        case core::SubgroupMatrixKind::kLeft:
+                            use = spv::CooperativeMatrixUse::MatrixAKHR;
+                            break;
+                        case core::SubgroupMatrixKind::kRight:
+                            use = spv::CooperativeMatrixUse::MatrixBKHR;
+                            break;
+                        case core::SubgroupMatrixKind::kResult:
+                            use = spv::CooperativeMatrixUse::MatrixAccumulatorKHR;
+                            break;
+                        case core::SubgroupMatrixKind::kUndefined:
+                            TINT_UNREACHABLE();
+                    }
+                    module_.PushExtension("SPV_KHR_cooperative_matrix");
+                    module_.PushCapability(SpvCapabilityCooperativeMatrixKHR);
+                    module_.PushType(spv::Op::OpTypeCooperativeMatrixKHR,
+                                     {
+                                         id,
+                                         Type(sm->Type()),
+                                         scope,
+                                         rows,
+                                         cols,
+                                         Constant(ir_.constant_values.Get(u32(use))),
+                                     });
+                },  //
                 TINT_ICE_ON_NO_MATCH);
             return id;
         });
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index d2298a5..f5970b5 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -528,6 +528,23 @@
                              StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rgba8",  //
                                                 Dim::k2d, Format::kRgba8Unorm}));
 
+TEST_F(SpirvWriterTest, Type_SubgroupMatrix) {
+    b.Append(b.ir.root_block, [&] {  //
+        b.Var("left", ty.ptr<private_>(ty.subgroup_matrix_left(ty.f32(), 8, 4)));
+        b.Var("right", ty.ptr<private_>(ty.subgroup_matrix_right(ty.u32(), 4, 8)));
+        b.Var("result", ty.ptr<private_>(ty.subgroup_matrix_result(ty.i32(), 2, 2)));
+    });
+
+    Options options;
+    options.use_vulkan_memory_model = true;
+    ASSERT_TRUE(Generate(options)) << Error() << output_;
+    EXPECT_INST("OpCapability CooperativeMatrixKHR");
+    EXPECT_INST("OpExtension \"SPV_KHR_cooperative_matrix\"");
+    EXPECT_INST("%3 = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_4 %uint_8 %uint_0");
+    EXPECT_INST("%13 = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_8 %uint_4 %uint_1");
+    EXPECT_INST("%18 = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_2 %uint_2 %uint_2");
+}
+
 // Test that we can emit multiple types.
 // Includes types with the same opcode but different parameters.
 TEST_F(SpirvWriterTest, Type_Multiple) {