[spir-v] Generate SubgroupMatrix info in SPIR-V backend.
Update the SPIR-V backend to generate the SubgroupMatrix configuration
information and place it in the returned output.
Fixed: 459779965
Change-Id: Iac78d513d6df620248061c08be04a37ba0d144f7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/275095
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/api/common/BUILD.bazel b/src/tint/api/common/BUILD.bazel
index 6af3546..207e62ef 100644
--- a/src/tint/api/common/BUILD.bazel
+++ b/src/tint/api/common/BUILD.bazel
@@ -48,6 +48,7 @@
"resource_binding_config.h",
"resource_table_config.h",
"resource_type.h",
+ "subgroup_matrix.h",
"substitute_overrides_config.h",
"vertex_pulling_config.h",
],
diff --git a/src/tint/api/common/BUILD.cmake b/src/tint/api/common/BUILD.cmake
index 4080fb9..bfa9ef6 100644
--- a/src/tint/api/common/BUILD.cmake
+++ b/src/tint/api/common/BUILD.cmake
@@ -45,6 +45,7 @@
api/common/resource_binding_config.h
api/common/resource_table_config.h
api/common/resource_type.h
+ api/common/subgroup_matrix.h
api/common/substitute_overrides_config.h
api/common/vertex_pulling_config.cc
api/common/vertex_pulling_config.h
diff --git a/src/tint/api/common/BUILD.gn b/src/tint/api/common/BUILD.gn
index 889e2db..774cad7 100644
--- a/src/tint/api/common/BUILD.gn
+++ b/src/tint/api/common/BUILD.gn
@@ -51,6 +51,7 @@
"resource_binding_config.h",
"resource_table_config.h",
"resource_type.h",
+ "subgroup_matrix.h",
"substitute_overrides_config.h",
"vertex_pulling_config.cc",
"vertex_pulling_config.h",
diff --git a/src/tint/api/common/subgroup_matrix.h b/src/tint/api/common/subgroup_matrix.h
new file mode 100644
index 0000000..b74a94a
--- /dev/null
+++ b/src/tint/api/common/subgroup_matrix.h
@@ -0,0 +1,103 @@
+// Copyright 2025 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_API_COMMON_SUBGROUP_MATRIX_H_
+#define SRC_TINT_API_COMMON_SUBGROUP_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <unordered_set>
+
+#include "src/tint/utils/math/hash.h"
+
+namespace tint {
+
+/// The element types a subgroup matrix can be declared with.
+enum class SubgroupMatrixType : uint8_t {
+    kF16 = 0,
+    kF32,
+    kU8,
+    kI8,
+    kU32,
+    kI32,
+};
+
+/// A subgroup matrix multiply shape, together with its input and output element types.
+/// Members default to the same values value-initialization (`SubgroupMatrixMultiply{}`) produces,
+/// so a default-initialized instance never holds indeterminate values.
+struct SubgroupMatrixMultiply {
+    /// The M dimension of the multiply.
+    uint32_t M = 0;
+    /// The N dimension of the multiply.
+    uint32_t N = 0;
+    /// The K dimension of the multiply.
+    uint32_t K = 0;
+
+    /// The element type of the multiply inputs.
+    SubgroupMatrixType input_type = SubgroupMatrixType::kF16;
+    /// The element type of the multiply output.
+    SubgroupMatrixType output_type = SubgroupMatrixType::kF16;
+
+    /// @param o the multiply shape to compare against
+    /// @returns true if all dimensions and both element types are equal
+    bool operator==(const SubgroupMatrixMultiply& o) const {
+        return M == o.M && N == o.N && K == o.K && input_type == o.input_type &&
+               output_type == o.output_type;
+    }
+};
+
+/// Whether a subgroup matrix configuration describes an input matrix or a result matrix.
+enum class SubgroupMatrixDirection : uint8_t {
+    kInput,
+    kResult,
+};
+
+/// The dimensions, element type and direction of a subgroup matrix.
+/// Members default to the same values value-initialization (`SubgroupMatrixConfig{}`) produces,
+/// so a default-initialized instance never holds indeterminate values.
+struct SubgroupMatrixConfig {
+    /// The number of columns.
+    uint32_t columns = 0;
+    /// The number of rows.
+    uint32_t rows = 0;
+    /// The element type.
+    SubgroupMatrixType type = SubgroupMatrixType::kF16;
+    /// Whether this is an input or a result matrix.
+    SubgroupMatrixDirection direction = SubgroupMatrixDirection::kInput;
+
+    /// @param o the configuration to compare against
+    /// @returns true if all fields are equal
+    bool operator==(const SubgroupMatrixConfig& o) const {
+        return columns == o.columns && rows == o.rows && type == o.type && direction == o.direction;
+    }
+};
+
+}  // namespace tint
+
+/// Hashes a SubgroupMatrixMultiply so it can be stored in unordered containers.
+template <>
+class std::hash<tint::SubgroupMatrixMultiply> {
+  public:
+    /// @param multiply the multiply shape to hash
+    /// @returns a hash combining the M, N and K dimensions and both element types
+    std::size_t operator()(const tint::SubgroupMatrixMultiply& multiply) const {
+        return tint::Hash(multiply.M, multiply.N, multiply.K, multiply.input_type,
+                          multiply.output_type);
+    }
+};
+
+/// Hashes a SubgroupMatrixConfig so it can be stored in unordered containers.
+template <>
+class std::hash<tint::SubgroupMatrixConfig> {
+  public:
+    /// @param config the configuration to hash (taken by const reference, like the other
+    /// std::hash specializations, rather than by value)
+    /// @returns a hash combining the columns, rows, element type and direction
+    std::size_t operator()(const tint::SubgroupMatrixConfig& config) const {
+        return tint::Hash(config.columns, config.rows, config.type, config.direction);
+    }
+};
+
+namespace tint {
+
+/// The subgroup matrix usage gathered from a module: each distinct multiply shape and each
+/// distinct matrix configuration. Both sets are empty when no subgroup matrices are used.
+struct SubgroupMatrixInfo {
+    /// The distinct multiply shapes.
+    std::unordered_set<SubgroupMatrixMultiply> multiplies{};
+    /// The distinct matrix configurations.
+    std::unordered_set<SubgroupMatrixConfig> configs{};
+};
+
+}  // namespace tint
+
+#endif // SRC_TINT_API_COMMON_SUBGROUP_MATRIX_H_
diff --git a/src/tint/lang/core/ir/analysis/subgroup_matrix.h b/src/tint/lang/core/ir/analysis/subgroup_matrix.h
index d7d201f..6283ed2 100644
--- a/src/tint/lang/core/ir/analysis/subgroup_matrix.h
+++ b/src/tint/lang/core/ir/analysis/subgroup_matrix.h
@@ -28,10 +28,7 @@
#ifndef SRC_TINT_LANG_CORE_IR_ANALYSIS_SUBGROUP_MATRIX_H_
#define SRC_TINT_LANG_CORE_IR_ANALYSIS_SUBGROUP_MATRIX_H_
-#include <cstdint>
-#include <unordered_set>
-
-#include "src/tint/utils/math/hash.h"
+#include "src/tint/api/common/subgroup_matrix.h"
// Forward declarations.
namespace tint::core::ir {
@@ -40,71 +37,6 @@
namespace tint::core::ir::analysis {
-enum class SubgroupMatrixType : uint8_t {
- kF16 = 0,
- kF32,
- kU8,
- kI8,
- kU32,
- kI32,
-};
-
-struct SubgroupMatrixMultiply {
- uint32_t M;
- uint32_t N;
- uint32_t K;
-
- SubgroupMatrixType input_type;
- SubgroupMatrixType output_type;
-
- bool operator==(const SubgroupMatrixMultiply& o) const {
- return M == o.M && N == o.N && K == o.K && input_type == o.input_type &&
- output_type == o.output_type;
- }
-};
-
-enum class SubgroupMatrixDirection : uint8_t {
- kInput,
- kResult,
-};
-
-struct SubgroupMatrixConfig {
- uint32_t columns;
- uint32_t rows;
- SubgroupMatrixType type;
- SubgroupMatrixDirection direction;
-
- bool operator==(const SubgroupMatrixConfig& o) const {
- return columns == o.columns && rows == o.rows && type == o.type && direction == o.direction;
- }
-};
-
-} // namespace tint::core::ir::analysis
-
-template <>
-class std::hash<tint::core::ir::analysis::SubgroupMatrixMultiply> {
- public:
- inline std::size_t operator()(
- const tint::core::ir::analysis::SubgroupMatrixMultiply& sm) const {
- return tint::Hash(sm.M, sm.N, sm.K, sm.input_type, sm.output_type);
- }
-};
-
-template <>
-class std::hash<tint::core::ir::analysis::SubgroupMatrixConfig> {
- public:
- inline std::size_t operator()(const tint::core::ir::analysis::SubgroupMatrixConfig sm) const {
- return tint::Hash(sm.columns, sm.rows, sm.type, sm.direction);
- }
-};
-
-namespace tint::core::ir::analysis {
-
-struct SubgroupMatrixInfo {
- std::unordered_set<SubgroupMatrixMultiply> multiplies;
- std::unordered_set<SubgroupMatrixConfig> configs;
-};
-
/// Gathers information about the subgroup matrix configurations used in the module.
///
/// This returns two fundamental types of information on the subgroup matrix uses.
diff --git a/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc b/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc
index 579514d..3c41263 100644
--- a/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc
+++ b/src/tint/lang/msl/writer/raise/validate_subgroup_matrix.cc
@@ -57,8 +57,7 @@
break;
}
- if (i.type != core::ir::analysis::SubgroupMatrixType::kF32 &&
- i.type != core::ir::analysis::SubgroupMatrixType::kF16) {
+ if (i.type != SubgroupMatrixType::kF32 && i.type != SubgroupMatrixType::kF16) {
diagnostics_.AddError(Source{})
<< "subgroup_matrix requires a type of `f32` or `f16` for the selected device";
break;
diff --git a/src/tint/lang/spirv/writer/BUILD.bazel b/src/tint/lang/spirv/writer/BUILD.bazel
index 933dfbf..922f81c 100644
--- a/src/tint/lang/spirv/writer/BUILD.bazel
+++ b/src/tint/lang/spirv/writer/BUILD.bazel
@@ -50,6 +50,7 @@
"//src/tint/lang/core/constant",
"//src/tint/lang/core/intrinsic",
"//src/tint/lang/core/ir",
+ "//src/tint/lang/core/ir/analysis",
"//src/tint/lang/core/ir/transform",
"//src/tint/lang/core/type",
"//src/tint/utils",
diff --git a/src/tint/lang/spirv/writer/BUILD.cmake b/src/tint/lang/spirv/writer/BUILD.cmake
index 11fde65..86cc0e3 100644
--- a/src/tint/lang/spirv/writer/BUILD.cmake
+++ b/src/tint/lang/spirv/writer/BUILD.cmake
@@ -55,6 +55,7 @@
tint_lang_core_constant
tint_lang_core_intrinsic
tint_lang_core_ir
+ tint_lang_core_ir_analysis
tint_lang_core_ir_transform
tint_lang_core_type
tint_utils
diff --git a/src/tint/lang/spirv/writer/BUILD.gn b/src/tint/lang/spirv/writer/BUILD.gn
index e4050d5..9663aaf 100644
--- a/src/tint/lang/spirv/writer/BUILD.gn
+++ b/src/tint/lang/spirv/writer/BUILD.gn
@@ -55,6 +55,7 @@
"${tint_src_dir}/lang/core/constant",
"${tint_src_dir}/lang/core/intrinsic",
"${tint_src_dir}/lang/core/ir",
+ "${tint_src_dir}/lang/core/ir/analysis",
"${tint_src_dir}/lang/core/ir/transform",
"${tint_src_dir}/lang/core/type",
"${tint_src_dir}/utils",
diff --git a/src/tint/lang/spirv/writer/common/helper_test.h b/src/tint/lang/spirv/writer/common/helper_test.h
index 05949da..3d5d54a 100644
--- a/src/tint/lang/spirv/writer/common/helper_test.h
+++ b/src/tint/lang/spirv/writer/common/helper_test.h
@@ -102,6 +102,9 @@
/// Workgroup info
Output::WorkgroupInfo workgroup_info;
+ /// Subgroup Matrix Info
+ SubgroupMatrixInfo subgroup_matrix_info;
+
/// @returns the error string from the validation
std::string Error() const { return err_; }
@@ -128,6 +131,7 @@
return false;
}
workgroup_info = result->workgroup_info;
+ subgroup_matrix_info = result->subgroup_matrix_info;
return true;
}
diff --git a/src/tint/lang/spirv/writer/common/output.h b/src/tint/lang/spirv/writer/common/output.h
index 20f1e2c..91605ec 100644
--- a/src/tint/lang/spirv/writer/common/output.h
+++ b/src/tint/lang/spirv/writer/common/output.h
@@ -33,6 +33,8 @@
#include <cstdint>
#include <vector>
+#include "src/tint/api/common/subgroup_matrix.h"
+
namespace tint::spirv::writer {
/// The output produced when generating SPIR-V.
@@ -68,6 +70,9 @@
/// The workgroup size information, if the entry point was a compute shader
WorkgroupInfo workgroup_info{};
+
+ /// The subgroup matrix information
+ SubgroupMatrixInfo subgroup_matrix_info{};
};
} // namespace tint::spirv::writer
diff --git a/src/tint/lang/spirv/writer/type_test.cc b/src/tint/lang/spirv/writer/type_test.cc
index 3723321..85744ef 100644
--- a/src/tint/lang/spirv/writer/type_test.cc
+++ b/src/tint/lang/spirv/writer/type_test.cc
@@ -936,6 +936,23 @@
EXPECT_INST("%22 = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_2 %uint_2 %uint_2");
}
+// Checks that the SPIR-V writer reports the subgroup matrix configurations declared in the
+// module, and reports no multiplies when no multiply builtin is called.
+TEST_F(SpirvWriterTest, SubgroupMatrix_ConfigReturned) {
+    auto* fn = b.ComputeFunction("main");
+    b.Append(fn->Block(), [&] {
+        b.Var("left", ty.ptr<function>(ty.subgroup_matrix_left(ty.f32(), 8, 4)));
+        b.Var("right", ty.ptr<function>(ty.subgroup_matrix_right(ty.u32(), 4, 8)));
+        b.Var("result", ty.ptr<function>(ty.subgroup_matrix_result(ty.i32(), 2, 2)));
+        b.Return(fn);
+    });
+
+    Options options{
+        .entry_point_name = "main",
+        .use_vulkan_memory_model = true,
+    };
+    ASSERT_TRUE(Generate(options)) << Error() << output_;
+
+    // Three distinct matrix types are declared, and no multiply is performed.
+    EXPECT_EQ(3u, subgroup_matrix_info.configs.size());
+    EXPECT_TRUE(subgroup_matrix_info.multiplies.empty());
+
+    // The `result` matrix is the only i32 one; it alone should be reported as a result.
+    for (const auto& config : subgroup_matrix_info.configs) {
+        if (config.type == SubgroupMatrixType::kI32) {
+            EXPECT_EQ(SubgroupMatrixDirection::kResult, config.direction);
+        } else {
+            EXPECT_EQ(SubgroupMatrixDirection::kInput, config.direction);
+        }
+    }
+}
+
// Test that we can emit multiple types.
// Includes types with the same opcode but different parameters.
TEST_F(SpirvWriterTest, Type_Multiple) {
diff --git a/src/tint/lang/spirv/writer/writer.cc b/src/tint/lang/spirv/writer/writer.cc
index 59facd3..360ea37 100644
--- a/src/tint/lang/spirv/writer/writer.cc
+++ b/src/tint/lang/spirv/writer/writer.cc
@@ -31,6 +31,7 @@
#include <utility>
#include <vector>
+#include "src/tint/lang/core/ir/analysis/subgroup_matrix.h"
#include "src/tint/lang/core/ir/core_builtin_call.h"
#include "src/tint/lang/core/ir/referenced_module_vars.h"
#include "src/tint/lang/core/ir/validator.h"
@@ -149,7 +150,14 @@
return std::move(res.Failure());
}
- return Print(ir, options);
+ auto res = Print(ir, options);
+ if (res != Success) {
+ return res;
+ }
+
+ res->subgroup_matrix_info = core::ir::analysis::GatherSubgroupMatrixInfo(ir);
+
+ return res;
}
} // namespace tint::spirv::writer