[spir-v] Generate SubgroupMatrix info in SPIR-V backend.

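Update the SPIR-V backend to generate the SubgroupMatrix configuration
information and place it in the returned output. The SubgroupMatrix
types move from `core::ir::analysis` to `src/tint/api/common` so that
they can be exposed as part of the backend output.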

Fixed: 459779965
Change-Id: Iac78d513d6df620248061c08be04a37ba0d144f7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/275095
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/api/common/BUILD.bazel b/src/tint/api/common/BUILD.bazel
index 6af3546..207e62ef 100644
--- a/src/tint/api/common/BUILD.bazel
+++ b/src/tint/api/common/BUILD.bazel
@@ -48,6 +48,7 @@
     "resource_binding_config.h",
     "resource_table_config.h",
     "resource_type.h",
+    "subgroup_matrix.h",
     "substitute_overrides_config.h",
     "vertex_pulling_config.h",
   ],
diff --git a/src/tint/api/common/BUILD.cmake b/src/tint/api/common/BUILD.cmake
index 4080fb9..bfa9ef6 100644
--- a/src/tint/api/common/BUILD.cmake
+++ b/src/tint/api/common/BUILD.cmake
@@ -45,6 +45,7 @@
   api/common/resource_binding_config.h
   api/common/resource_table_config.h
   api/common/resource_type.h
+  api/common/subgroup_matrix.h
   api/common/substitute_overrides_config.h
   api/common/vertex_pulling_config.cc
   api/common/vertex_pulling_config.h
diff --git a/src/tint/api/common/BUILD.gn b/src/tint/api/common/BUILD.gn
index 889e2db..774cad7 100644
--- a/src/tint/api/common/BUILD.gn
+++ b/src/tint/api/common/BUILD.gn
@@ -51,6 +51,7 @@
     "resource_binding_config.h",
     "resource_table_config.h",
     "resource_type.h",
+    "subgroup_matrix.h",
     "substitute_overrides_config.h",
     "vertex_pulling_config.cc",
     "vertex_pulling_config.h",
diff --git a/src/tint/api/common/subgroup_matrix.h b/src/tint/api/common/subgroup_matrix.h
new file mode 100644
index 0000000..b74a94a
--- /dev/null
+++ b/src/tint/api/common/subgroup_matrix.h
@@ -0,0 +1,136 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_API_COMMON_SUBGROUP_MATRIX_H_
+#define SRC_TINT_API_COMMON_SUBGROUP_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <unordered_set>
+
+#include "src/tint/utils/math/hash.h"
+
+namespace tint {
+
+/// The element type of a subgroup matrix.
+enum class SubgroupMatrixType : uint8_t {
+    kF16 = 0,
+    kF32,
+    kU8,
+    kI8,
+    kU32,
+    kI32,
+};
+
+/// A subgroup matrix multiply configuration.
+struct SubgroupMatrixMultiply {
+    /// The M dimension of the multiply
+    uint32_t M;
+    /// The N dimension of the multiply
+    uint32_t N;
+    /// The K dimension of the multiply
+    uint32_t K;
+
+    /// The element type of the multiply inputs
+    SubgroupMatrixType input_type;
+    /// The element type of the multiply output
+    SubgroupMatrixType output_type;
+
+    /// Equality operator
+    /// @param o the multiply configuration to compare against
+    /// @returns true if every field of this and @p o are equal
+    bool operator==(const SubgroupMatrixMultiply& o) const {
+        return M == o.M && N == o.N && K == o.K && input_type == o.input_type &&
+               output_type == o.output_type;
+    }
+};
+
+/// The role a subgroup matrix plays in a configuration.
+enum class SubgroupMatrixDirection : uint8_t {
+    /// An input matrix
+    kInput,
+    /// A result matrix
+    kResult,
+};
+
+/// A subgroup matrix shape, element type and direction.
+struct SubgroupMatrixConfig {
+    /// The number of columns
+    uint32_t columns;
+    /// The number of rows
+    uint32_t rows;
+    /// The element type
+    SubgroupMatrixType type;
+    /// Whether the matrix is an input or a result
+    SubgroupMatrixDirection direction;
+
+    /// Equality operator
+    /// @param o the configuration to compare against
+    /// @returns true if every field of this and @p o are equal
+    bool operator==(const SubgroupMatrixConfig& o) const {
+        return columns == o.columns && rows == o.rows && type == o.type && direction == o.direction;
+    }
+};
+
+}  // namespace tint
+
+/// std::hash specialization for tint::SubgroupMatrixMultiply
+template <>
+class std::hash<tint::SubgroupMatrixMultiply> {
+  public:
+    /// @param sm the multiply configuration to hash
+    /// @returns a hash of all the fields of @p sm
+    inline std::size_t operator()(const tint::SubgroupMatrixMultiply& sm) const {
+        return tint::Hash(sm.M, sm.N, sm.K, sm.input_type, sm.output_type);
+    }
+};
+
+/// std::hash specialization for tint::SubgroupMatrixConfig
+template <>
+class std::hash<tint::SubgroupMatrixConfig> {
+  public:
+    /// @param sm the configuration to hash
+    /// @returns a hash of all the fields of @p sm
+    inline std::size_t operator()(const tint::SubgroupMatrixConfig& sm) const {
+        return tint::Hash(sm.columns, sm.rows, sm.type, sm.direction);
+    }
+};
+
+namespace tint {
+
+/// The subgroup matrix multiplies and configurations used by a module.
+struct SubgroupMatrixInfo {
+    /// The set of subgroup matrix multiply configurations
+    std::unordered_set<SubgroupMatrixMultiply> multiplies;
+    /// The set of subgroup matrix configurations
+    std::unordered_set<SubgroupMatrixConfig> configs;
+};
+
+}  // namespace tint
+
+#endif  // SRC_TINT_API_COMMON_SUBGROUP_MATRIX_H_
diff --git a/src/tint/lang/core/ir/analysis/subgroup_matrix.h b/src/tint/lang/core/ir/analysis/subgroup_matrix.h
index d7d201f..6283ed2 100644
--- a/src/tint/lang/core/ir/analysis/subgroup_matrix.h
+++ b/src/tint/lang/core/ir/analysis/subgroup_matrix.h
@@ -28,10 +28,7 @@
 #ifndef SRC_TINT_LANG_CORE_IR_ANALYSIS_SUBGROUP_MATRIX_H_
 #define SRC_TINT_LANG_CORE_IR_ANALYSIS_SUBGROUP_MATRIX_H_
 
-#include <cstdint>
-#include <unordered_set>
-
-#include "src/tint/utils/math/hash.h"
+#include "src/tint/api/common/subgroup_matrix.h"
 
 // Forward declarations.
 namespace tint::core::ir {
@@ -40,71 +37,6 @@
 
 namespace tint::core::ir::analysis {
 
-enum class SubgroupMatrixType : uint8_t {
-    kF16 = 0,
-    kF32,
-    kU8,
-    kI8,
-    kU32,
-    kI32,
-};
-
-struct SubgroupMatrixMultiply {
-    uint32_t M;
-    uint32_t N;
-    uint32_t K;
-
-    SubgroupMatrixType input_type;
-    SubgroupMatrixType output_type;
-
-    bool operator==(const SubgroupMatrixMultiply& o) const {
-        return M == o.M && N == o.N && K == o.K && input_type == o.input_type &&
-               output_type == o.output_type;
-    }
-};
-
-enum class SubgroupMatrixDirection : uint8_t {
-    kInput,
-    kResult,
-};
-
-struct SubgroupMatrixConfig {
-    uint32_t columns;
-    uint32_t rows;
-    SubgroupMatrixType type;
-    SubgroupMatrixDirection direction;
-
-    bool operator==(const SubgroupMatrixConfig& o) const {
-        return columns == o.columns && rows == o.rows && type == o.type && direction == o.direction;
-    }
-};
-
-}  // namespace tint::core::ir::analysis
-
-template <>
-class std::hash<tint::core::ir::analysis::SubgroupMatrixMultiply> {
-  public:
-    inline std::size_t operator()(
-        const tint::core::ir::analysis::SubgroupMatrixMultiply& sm) const {
-        return tint::Hash(sm.M, sm.N, sm.K, sm.input_type, sm.output_type);
-    }
-};
-
-template <>
-class std::hash<tint::core::ir::analysis::SubgroupMatrixConfig> {
-  public:
-    inline std::size_t operator()(const tint::core::ir::analysis::SubgroupMatrixConfig sm) const {
-        return tint::Hash(sm.columns, sm.rows, sm.type, sm.direction);
-    }
-};
-
-namespace tint::core::ir::analysis {
-
-struct SubgroupMatrixInfo {
-    std::unordered_set<SubgroupMatrixMultiply> multiplies;
-    std::unordered_set<SubgroupMatrixConfig> configs;
-};
-
 /// Gathers information about the subgroup matrix configurations used in the module.
 ///
 /// This returns two fundamental types of information on the subgroup matrix uses.
diff --git a/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc b/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc
index 579514d..3c41263 100644
--- a/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc
+++ b/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc
@@ -57,8 +57,7 @@
                 break;
             }
 
-            if (i.type != core::ir::analysis::SubgroupMatrixType::kF32 &&
-                i.type != core::ir::analysis::SubgroupMatrixType::kF16) {
+            if (i.type != SubgroupMatrixType::kF32 && i.type != SubgroupMatrixType::kF16) {
                 diagnostics_.AddError(Source{})
                     << "subgroup_matrix requires a type of `f32` or `f16` for the selected device";
                 break;
diff --git a/src/tint/lang/spirv/writer/BUILD.bazel b/src/tint/lang/spirv/writer/BUILD.bazel
index 933dfbf..922f81c 100644
--- a/src/tint/lang/spirv/writer/BUILD.bazel
+++ b/src/tint/lang/spirv/writer/BUILD.bazel
@@ -50,6 +50,7 @@
     "//src/tint/lang/core/constant",
     "//src/tint/lang/core/intrinsic",
     "//src/tint/lang/core/ir",
+    "//src/tint/lang/core/ir/analysis",
     "//src/tint/lang/core/ir/transform",
     "//src/tint/lang/core/type",
     "//src/tint/utils",
diff --git a/src/tint/lang/spirv/writer/BUILD.cmake b/src/tint/lang/spirv/writer/BUILD.cmake
index 11fde65..86cc0e3 100644
--- a/src/tint/lang/spirv/writer/BUILD.cmake
+++ b/src/tint/lang/spirv/writer/BUILD.cmake
@@ -55,6 +55,7 @@
   tint_lang_core_constant
   tint_lang_core_intrinsic
   tint_lang_core_ir
+  tint_lang_core_ir_analysis
   tint_lang_core_ir_transform
   tint_lang_core_type
   tint_utils
diff --git a/src/tint/lang/spirv/writer/BUILD.gn b/src/tint/lang/spirv/writer/BUILD.gn
index e4050d5..9663aaf 100644
--- a/src/tint/lang/spirv/writer/BUILD.gn
+++ b/src/tint/lang/spirv/writer/BUILD.gn
@@ -55,6 +55,7 @@
       "${tint_src_dir}/lang/core/constant",
       "${tint_src_dir}/lang/core/intrinsic",
       "${tint_src_dir}/lang/core/ir",
+      "${tint_src_dir}/lang/core/ir/analysis",
       "${tint_src_dir}/lang/core/ir/transform",
       "${tint_src_dir}/lang/core/type",
       "${tint_src_dir}/utils",
diff --git a/src/tint/lang/spirv/writer/common/helper_test.h b/src/tint/lang/spirv/writer/common/helper_test.h
index 05949da..3d5d54a 100644
--- a/src/tint/lang/spirv/writer/common/helper_test.h
+++ b/src/tint/lang/spirv/writer/common/helper_test.h
@@ -102,6 +102,9 @@
     /// Workgroup info
     Output::WorkgroupInfo workgroup_info;
 
+    /// Subgroup Matrix Info
+    SubgroupMatrixInfo subgroup_matrix_info;
+
     /// @returns the error string from the validation
     std::string Error() const { return err_; }
 
@@ -128,6 +131,7 @@
             return false;
         }
         workgroup_info = result->workgroup_info;
+        subgroup_matrix_info = result->subgroup_matrix_info;
 
         return true;
     }
diff --git a/src/tint/lang/spirv/writer/common/output.h b/src/tint/lang/spirv/writer/common/output.h
index 20f1e2c..91605ec 100644
--- a/src/tint/lang/spirv/writer/common/output.h
+++ b/src/tint/lang/spirv/writer/common/output.h
@@ -33,6 +33,8 @@
 #include <cstdint>
 #include <vector>
 
+#include "src/tint/api/common/subgroup_matrix.h"
+
 namespace tint::spirv::writer {
 
 /// The output produced when generating SPIR-V.
@@ -68,6 +70,9 @@
 
     /// The workgroup size information, if the entry point was a compute shader
     WorkgroupInfo workgroup_info{};
+
+    /// The subgroup matrix information
+    SubgroupMatrixInfo subgroup_matrix_info{};
 };
 
 }  // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index 3723321..85744ef 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -936,6 +936,31 @@
     EXPECT_INST("%22 = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_2 %uint_2 %uint_2");
 }
 
+// Test that the subgroup matrix configurations used by the module are returned in the output.
+TEST_F(SpirvWriterTest, SubgroupMatrix_ConfigReturned) {
+    auto* fn = b.ComputeFunction("main");
+    b.Append(fn->Block(), [&] {
+        b.Var("left", ty.ptr<function>(ty.subgroup_matrix_left(ty.f32(), 8, 4)));
+        b.Var("right", ty.ptr<function>(ty.subgroup_matrix_right(ty.u32(), 4, 8)));
+        b.Var("result", ty.ptr<function>(ty.subgroup_matrix_result(ty.i32(), 2, 2)));
+        b.Return(fn);
+    });
+
+    Options options{
+        .entry_point_name = "main",
+        .use_vulkan_memory_model = true,
+    };
+    ASSERT_TRUE(Generate(options)) << Error() << output_;
+
+    // Each subgroup matrix type is reported as a distinct config, and no multiplies are used.
+    const auto& configs = subgroup_matrix_info.configs;
+    EXPECT_EQ(3u, configs.size());
+    EXPECT_TRUE(configs.count({8, 4, SubgroupMatrixType::kF32, SubgroupMatrixDirection::kInput}));
+    EXPECT_TRUE(configs.count({4, 8, SubgroupMatrixType::kU32, SubgroupMatrixDirection::kInput}));
+    EXPECT_TRUE(configs.count({2, 2, SubgroupMatrixType::kI32, SubgroupMatrixDirection::kResult}));
+    EXPECT_TRUE(subgroup_matrix_info.multiplies.empty());
+}
+
 // Test that we can emit multiple types.
 // Includes types with the same opcode but different parameters.
 TEST_F(SpirvWriterTest, Type_Multiple) {
diff --git a/src/tint/lang/spirv/writer/writer.cc b/src/tint/lang/spirv/writer/writer.cc
index 59facd3..360ea37 100644
--- a/src/tint/lang/spirv/writer/writer.cc
+++ b/src/tint/lang/spirv/writer/writer.cc
@@ -31,6 +31,7 @@
 #include <utility>
 #include <vector>
 
+#include "src/tint/lang/core/ir/analysis/subgroup_matrix.h"
 #include "src/tint/lang/core/ir/core_builtin_call.h"
 #include "src/tint/lang/core/ir/referenced_module_vars.h"
 #include "src/tint/lang/core/ir/validator.h"
@@ -149,7 +150,14 @@
         return std::move(res.Failure());
     }
 
-    return Print(ir, options);
+    auto res = Print(ir, options);
+    if (res != Success) {
+        return res;
+    }
+
+    res->subgroup_matrix_info = core::ir::analysis::GatherSubgroupMatrixInfo(ir);
+
+    return res;
 }
 
 }  // namespace tint::spirv::writer