[spirv] Emit subgroup_matrix type
Require the Vulkan Memory Model, and enable the CooperativeMatrixKHR
extension and capability.
Bug: 348702031
Change-Id: I91fadfad4aa4e9422d4bb4415284ceb466401772
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/224655
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc
index 1d4bdb5..03135ff 100644
--- a/src/tint/lang/spirv/writer/printer/printer.cc
+++ b/src/tint/lang/spirv/writer/printer/printer.cc
@@ -89,6 +89,7 @@
#include "src/tint/lang/core/type/sampler.h"
#include "src/tint/lang/core/type/storage_texture.h"
#include "src/tint/lang/core/type/struct.h"
+#include "src/tint/lang/core/type/subgroup_matrix.h"
#include "src/tint/lang/core/type/texture.h"
#include "src/tint/lang/core/type/type.h"
#include "src/tint/lang/core/type/u32.h"
@@ -539,6 +540,37 @@
[&](const type::SampledImage* s) {
module_.PushType(spv::Op::OpTypeSampledImage, {id, Type(s->Image())});
}, //
+ [&](const core::type::SubgroupMatrix* sm) {
+ TINT_ASSERT(options_.use_vulkan_memory_model);
+ auto scope = Constant(ir_.constant_values.Get(u32(spv::Scope::Subgroup)));
+ auto cols = Constant(ir_.constant_values.Get(u32(sm->Columns())));
+ auto rows = Constant(ir_.constant_values.Get(u32(sm->Rows())));
+ spv::CooperativeMatrixUse use = spv::CooperativeMatrixUse::Max;
+ switch (sm->Kind()) {
+ case core::SubgroupMatrixKind::kLeft:
+ use = spv::CooperativeMatrixUse::MatrixAKHR;
+ break;
+ case core::SubgroupMatrixKind::kRight:
+ use = spv::CooperativeMatrixUse::MatrixBKHR;
+ break;
+ case core::SubgroupMatrixKind::kResult:
+ use = spv::CooperativeMatrixUse::MatrixAccumulatorKHR;
+ break;
+ case core::SubgroupMatrixKind::kUndefined:
+ TINT_UNREACHABLE();
+ }
+ module_.PushExtension("SPV_KHR_cooperative_matrix");
+ module_.PushCapability(SpvCapabilityCooperativeMatrixKHR);
+ module_.PushType(spv::Op::OpTypeCooperativeMatrixKHR,
+ {
+ id,
+ Type(sm->Type()),
+ scope,
+ rows,
+ cols,
+ Constant(ir_.constant_values.Get(u32(use))),
+ });
+ }, //
TINT_ICE_ON_NO_MATCH);
return id;
});
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index d2298a5..f5970b5 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -528,6 +528,23 @@
StorageTextureCase{" = OpTypeImage %float 2D 0 0 0 2 Rgba8", //
Dim::k2d, Format::kRgba8Unorm}));
+TEST_F(SpirvWriterTest, Type_SubgroupMatrix) {
+ b.Append(b.ir.root_block, [&] { //
+ b.Var("left", ty.ptr<private_>(ty.subgroup_matrix_left(ty.f32(), 8, 4)));
+ b.Var("right", ty.ptr<private_>(ty.subgroup_matrix_right(ty.u32(), 4, 8)));
+ b.Var("result", ty.ptr<private_>(ty.subgroup_matrix_result(ty.i32(), 2, 2)));
+ });
+
+ Options options;
+ options.use_vulkan_memory_model = true;
+ ASSERT_TRUE(Generate(options)) << Error() << output_;
+ EXPECT_INST("OpCapability CooperativeMatrixKHR");
+ EXPECT_INST("OpExtension \"SPV_KHR_cooperative_matrix\"");
+ EXPECT_INST("%3 = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_4 %uint_8 %uint_0");
+ EXPECT_INST("%13 = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_8 %uint_4 %uint_1");
+ EXPECT_INST("%18 = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_2 %uint_2 %uint_2");
+}
+
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpirvWriterTest, Type_Multiple) {