Add subgroup matrix type.
This CL adds a `SubgroupMatrix` type which allows creating a `Left`,
`Right` or `Result` matrix.
Bug: 348702031
Change-Id: Ic2b2998b1aec5a15cea61b9a37e040b3d7752554
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/202394
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/core/type/BUILD.bazel b/src/tint/lang/core/type/BUILD.bazel
index 068b669..56cddca 100644
--- a/src/tint/lang/core/type/BUILD.bazel
+++ b/src/tint/lang/core/type/BUILD.bazel
@@ -70,6 +70,7 @@
"scalar.cc",
"storage_texture.cc",
"struct.cc",
+ "subgroup_matrix.cc",
"texture.cc",
"texture_dimension.cc",
"type.cc",
@@ -112,6 +113,7 @@
"scalar.h",
"storage_texture.h",
"struct.h",
+ "subgroup_matrix.h",
"texture.h",
"texture_dimension.h",
"type.h",
@@ -166,6 +168,7 @@
"sampler_test.cc",
"storage_texture_test.cc",
"struct_test.cc",
+ "subgroup_matrix_test.cc",
"texture_test.cc",
"type_test.cc",
"u32_test.cc",
diff --git a/src/tint/lang/core/type/BUILD.cmake b/src/tint/lang/core/type/BUILD.cmake
index 50a64de..9981a56 100644
--- a/src/tint/lang/core/type/BUILD.cmake
+++ b/src/tint/lang/core/type/BUILD.cmake
@@ -102,6 +102,8 @@
lang/core/type/storage_texture.h
lang/core/type/struct.cc
lang/core/type/struct.h
+ lang/core/type/subgroup_matrix.cc
+ lang/core/type/subgroup_matrix.h
lang/core/type/texture.cc
lang/core/type/texture.h
lang/core/type/texture_dimension.cc
@@ -164,6 +166,7 @@
lang/core/type/sampler_test.cc
lang/core/type/storage_texture_test.cc
lang/core/type/struct_test.cc
+ lang/core/type/subgroup_matrix_test.cc
lang/core/type/texture_test.cc
lang/core/type/type_test.cc
lang/core/type/u32_test.cc
diff --git a/src/tint/lang/core/type/BUILD.gn b/src/tint/lang/core/type/BUILD.gn
index 4cfe152..988fdf5 100644
--- a/src/tint/lang/core/type/BUILD.gn
+++ b/src/tint/lang/core/type/BUILD.gn
@@ -107,6 +107,8 @@
"storage_texture.h",
"struct.cc",
"struct.h",
+ "subgroup_matrix.cc",
+ "subgroup_matrix.h",
"texture.cc",
"texture.h",
"texture_dimension.cc",
@@ -166,6 +168,7 @@
"sampler_test.cc",
"storage_texture_test.cc",
"struct_test.cc",
+ "subgroup_matrix_test.cc",
"texture_test.cc",
"type_test.cc",
"u32_test.cc",
diff --git a/src/tint/lang/core/type/manager.cc b/src/tint/lang/core/type/manager.cc
index a65f83d..63e6071 100644
--- a/src/tint/lang/core/type/manager.cc
+++ b/src/tint/lang/core/type/manager.cc
@@ -180,6 +180,13 @@
return mat(inner, 4, 4);
}
+const core::type::SubgroupMatrix* Manager::subgroup_matrix(enum SubgroupMatrix::Kind kind,
+ const core::type::Type* inner,
+ uint32_t rows,
+ uint32_t cols) {
+ return Get<core::type::SubgroupMatrix>(kind, inner, rows, cols);
+}
+
const core::type::Array* Manager::array(const core::type::Type* elem_ty,
uint32_t count,
uint32_t stride /* = 0*/) {
diff --git a/src/tint/lang/core/type/manager.h b/src/tint/lang/core/type/manager.h
index a8c42e4..64b7db3 100644
--- a/src/tint/lang/core/type/manager.h
+++ b/src/tint/lang/core/type/manager.h
@@ -38,6 +38,7 @@
#include "src/tint/lang/core/type/external_texture.h"
#include "src/tint/lang/core/type/sampler.h"
#include "src/tint/lang/core/type/struct.h"
+#include "src/tint/lang/core/type/subgroup_matrix.h"
#include "src/tint/lang/core/type/type.h"
#include "src/tint/lang/core/type/unique_node.h"
#include "src/tint/utils/containers/unique_allocator.h"
@@ -425,6 +426,56 @@
return mat(Get<T>(), C, R);
}
+ /// @param kind the subgroup matrix kind
+ /// @param inner the inner type
+ /// @param rows the number of rows
+ /// @param cols the number of columns
+ /// @returns the subgroup_matrix type
+ const core::type::SubgroupMatrix* subgroup_matrix(enum SubgroupMatrix::Kind kind,
+ const core::type::Type* inner,
+ uint32_t rows,
+ uint32_t cols);
+
+ /// @param inner the inner type
+ /// @param rows the number of rows
+ /// @param cols the number of columns
+ /// @returns the subgroup_matrix type
+ const core::type::SubgroupMatrix* subgroup_matrix_left(const core::type::Type* inner,
+ uint32_t rows,
+ uint32_t cols) {
+ return subgroup_matrix(SubgroupMatrix::Kind::kLeft, inner, rows, cols);
+ }
+
+ /// @param inner the inner type
+ /// @param rows the number of rows
+ /// @param cols the number of columns
+ /// @returns the subgroup_matrix type
+ const core::type::SubgroupMatrix* subgroup_matrix_right(const core::type::Type* inner,
+ uint32_t rows,
+ uint32_t cols) {
+ return subgroup_matrix(SubgroupMatrix::Kind::kRight, inner, rows, cols);
+ }
+
+ /// @param inner the inner type
+ /// @param rows the number of rows
+ /// @param cols the number of columns
+ /// @returns the subgroup_matrix type
+ const core::type::SubgroupMatrix* subgroup_matrix_result(const core::type::Type* inner,
+ uint32_t rows,
+ uint32_t cols) {
+ return subgroup_matrix(SubgroupMatrix::Kind::kResult, inner, rows, cols);
+ }
+
+ /// @tparam K the kind of the matrix
+ /// @tparam T the element type
+ /// @tparam R the number of rows in the matrix
+ /// @tparam C the number of columns in the matrix
+ /// @returns a matrix with the given number of columns and rows
+ template <enum SubgroupMatrix::Kind K, typename T, uint32_t R, uint32_t C>
+ const core::type::SubgroupMatrix* subgroup_matrix() {
+ return subgroup_matrix(K, Get<T>(), C, R);
+ }
+
/// @param elem_ty the array element type
/// @param count the array element count
/// @param stride the optional array element stride
diff --git a/src/tint/lang/core/type/subgroup_matrix.cc b/src/tint/lang/core/type/subgroup_matrix.cc
new file mode 100644
index 0000000..a836e98
--- /dev/null
+++ b/src/tint/lang/core/type/subgroup_matrix.cc
@@ -0,0 +1,88 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/type/subgroup_matrix.h"
+
+#include "src/tint/lang/core/type/manager.h"
+
+TINT_INSTANTIATE_TYPEINFO(tint::core::type::SubgroupMatrix);
+
+namespace tint::core::type {
+
+SubgroupMatrix::SubgroupMatrix(enum Kind kind,
+ const type::Type* subtype,
+ uint32_t rows,
+ uint32_t columns)
+ : Base(Hash(tint::TypeCode::Of<SubgroupMatrix>().bits, kind, rows, columns, subtype),
+ core::type::Flags{
+ Flag::kConstructable,
+ Flag::kCreationFixedFootprint,
+ Flag::kFixedFootprint,
+ }),
+ kind_(kind),
+ subtype_(subtype),
+ rows_(rows),
+ columns_(columns) {}
+
+SubgroupMatrix::~SubgroupMatrix() = default;
+
+bool SubgroupMatrix::Equals(const UniqueNode& other) const {
+ if (auto* v = other.As<SubgroupMatrix>()) {
+ return v->kind_ == kind_ && v->rows_ == rows_ && v->columns_ == columns_ &&
+ v->subtype_ == subtype_;
+ }
+ return false;
+}
+
+uint32_t SubgroupMatrix::Align() const {
+ return subtype_->Align();
+}
+
+std::string SubgroupMatrix::FriendlyName() const {
+ StringStream out;
+ out << "subgroup_matrix_";
+ switch (kind_) {
+ case Kind::kLeft:
+ out << "left";
+ break;
+ case Kind::kRight:
+ out << "right";
+ break;
+ case Kind::kResult:
+ out << "result";
+ break;
+ }
+ out << "<" << subtype_->FriendlyName() << ", " << rows_ << ", " << columns_ << ">";
+ return out.str();
+}
+
+SubgroupMatrix* SubgroupMatrix::Clone(CloneContext& ctx) const {
+ auto* ty = subtype_->Clone(ctx);
+ return ctx.dst.mgr->Get<SubgroupMatrix>(kind_, ty, rows_, columns_);
+}
+
+} // namespace tint::core::type
diff --git a/src/tint/lang/core/type/subgroup_matrix.h b/src/tint/lang/core/type/subgroup_matrix.h
new file mode 100644
index 0000000..612c53f
--- /dev/null
+++ b/src/tint/lang/core/type/subgroup_matrix.h
@@ -0,0 +1,95 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef SRC_TINT_LANG_CORE_TYPE_SUBGROUP_MATRIX_H_
+#define SRC_TINT_LANG_CORE_TYPE_SUBGROUP_MATRIX_H_
+
+#include <string>
+
+#include "src/tint/lang/core/type/clone_context.h"
+#include "src/tint/lang/core/type/type.h"
+
+namespace tint::core::type {
+
+/// A subgroup_matrix type
+class SubgroupMatrix : public Castable<SubgroupMatrix, Type> {
+ public:
+ /// The kind of the subgroup matrix
+ enum class Kind : uint8_t {
+ /// A left matrix
+ kLeft,
+ /// A right matrix
+ kRight,
+ /// A result matrix
+ kResult,
+ };
+
+ /// Constructor
+ /// @param kind the kind of the matrix
+ /// @param subtype the inner type of the matrix
+ /// @param rows the number of rows in the matrix
+ /// @param columns the number of columns in the matrix
+ SubgroupMatrix(Kind kind, const Type* subtype, uint32_t rows, uint32_t columns);
+
+ /// Destructor
+ ~SubgroupMatrix() override;
+
+ /// @param other the other node to compare against
+ /// @returns true if the this type is equal to @p other
+ bool Equals(const UniqueNode& other) const override;
+
+ /// @returns the kind of the matrix
+ SubgroupMatrix::Kind Kind() const { return kind_; }
+ /// @returns the type of the matrix
+ const type::Type* Type() const { return subtype_; }
+ /// @returns the number of rows in the matrix
+ uint32_t Rows() const { return rows_; }
+ /// @returns the number of columns in the matrix
+ uint32_t Columns() const { return columns_; }
+
+ /// @returns the alignment in bytes of the type. This may include tail
+ /// padding.
+ uint32_t Align() const override;
+
+ /// @returns the name for this type that closely resembles how it would be
+ /// declared in WGSL.
+ std::string FriendlyName() const override;
+
+ /// @param ctx the clone context
+ /// @returns a clone of this type
+ SubgroupMatrix* Clone(CloneContext& ctx) const override;
+
+ private:
+ const enum Kind kind_;
+ const type::Type* const subtype_;
+ const uint32_t rows_;
+ const uint32_t columns_;
+};
+
+} // namespace tint::core::type
+
+#endif // SRC_TINT_LANG_CORE_TYPE_SUBGROUP_MATRIX_H_
diff --git a/src/tint/lang/core/type/subgroup_matrix_test.cc b/src/tint/lang/core/type/subgroup_matrix_test.cc
new file mode 100644
index 0000000..a1b418a
--- /dev/null
+++ b/src/tint/lang/core/type/subgroup_matrix_test.cc
@@ -0,0 +1,141 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/core/type/helper_test.h"
+
+#include "src/tint/lang/core/type/f32.h"
+#include "src/tint/lang/core/type/i8.h"
+#include "src/tint/lang/core/type/subgroup_matrix.h"
+
+namespace tint::core::type {
+namespace {
+
+using SubgroupMatrixTest = TestHelper;
+
+TEST_F(SubgroupMatrixTest, Creation) {
+ auto* f32 = create<F32>();
+
+ auto* l1 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kLeft, f32, 3u, 4u);
+
+ EXPECT_EQ(l1->Type(), f32);
+ EXPECT_EQ(l1->Kind(), SubgroupMatrix::Kind::kLeft);
+ EXPECT_EQ(l1->Rows(), 3u);
+ EXPECT_EQ(l1->Columns(), 4u);
+}
+
+TEST_F(SubgroupMatrixTest, Creation_TypeManager) {
+ core::type::Manager mgr;
+
+ {
+ auto* l = mgr.subgroup_matrix(SubgroupMatrix::Kind::kRight, mgr.f32(), 2, 4);
+ ASSERT_NE(l, nullptr);
+ EXPECT_EQ(SubgroupMatrix::Kind::kRight, l->Kind());
+ EXPECT_EQ(mgr.f32(), l->Type());
+ EXPECT_EQ(2u, l->Rows());
+ EXPECT_EQ(4u, l->Columns());
+ }
+
+ {
+ auto* l = mgr.subgroup_matrix_right(mgr.f32(), 2, 4);
+ EXPECT_EQ(SubgroupMatrix::Kind::kRight, l->Kind());
+ }
+ {
+ auto* l = mgr.subgroup_matrix_left(mgr.f32(), 2, 4);
+ EXPECT_EQ(SubgroupMatrix::Kind::kLeft, l->Kind());
+ }
+ {
+ auto* l = mgr.subgroup_matrix_result(mgr.f32(), 2, 4);
+ EXPECT_EQ(SubgroupMatrix::Kind::kResult, l->Kind());
+ }
+}
+
+TEST_F(SubgroupMatrixTest, Hash) {
+ auto* a = create<SubgroupMatrix>(SubgroupMatrix::Kind::kRight, create<I32>(), 3u, 4u);
+ auto* b = create<SubgroupMatrix>(SubgroupMatrix::Kind::kRight, create<I32>(), 3u, 4u);
+
+ EXPECT_EQ(a->unique_hash, b->unique_hash);
+}
+
+TEST_F(SubgroupMatrixTest, Equals) {
+ auto* f32 = create<F32>();
+ auto* i8 = create<I8>();
+
+ auto* l1 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kLeft, f32, 3u, 4u);
+ auto* l2 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kLeft, f32, 3u, 4u);
+ auto* l3 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kLeft, i8, 3u, 4u);
+ auto* l4 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kLeft, f32, 4u, 3u);
+
+ auto* r1 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kRight, f32, 3u, 4u);
+ auto* r2 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kRight, f32, 3u, 4u);
+ auto* res1 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kResult, f32, 3u, 4u);
+ auto* res2 = create<SubgroupMatrix>(SubgroupMatrix::Kind::kResult, f32, 3u, 4u);
+
+ EXPECT_EQ(l1, l2);
+ EXPECT_NE(l1, l3);
+ EXPECT_NE(l1, l4);
+ EXPECT_NE(l1, r1);
+ EXPECT_NE(l1, res1);
+
+ EXPECT_EQ(r1, r2);
+ EXPECT_NE(r1, res1);
+
+ EXPECT_EQ(res1, res2);
+}
+
+TEST_F(SubgroupMatrixTest, FriendlyName_Left) {
+ I8 i8;
+ SubgroupMatrix m{SubgroupMatrix::Kind::kLeft, &i8, 2, 4};
+ EXPECT_EQ(m.FriendlyName(), "subgroup_matrix_left<i8, 2, 4>");
+}
+
+TEST_F(SubgroupMatrixTest, FriendlyName_Right) {
+ F32 f32;
+ SubgroupMatrix m{SubgroupMatrix::Kind::kRight, &f32, 8, 8};
+ EXPECT_EQ(m.FriendlyName(), "subgroup_matrix_right<f32, 8, 8>");
+}
+
+TEST_F(SubgroupMatrixTest, FriendlyName_Result) {
+ U32 u32;
+ SubgroupMatrix m{SubgroupMatrix::Kind::kResult, &u32, 32, 32};
+ EXPECT_EQ(m.FriendlyName(), "subgroup_matrix_result<u32, 32, 32>");
+}
+
+TEST_F(SubgroupMatrixTest, Clone) {
+ auto* a = create<SubgroupMatrix>(SubgroupMatrix::Kind::kResult, create<I32>(), 3u, 4u);
+
+ core::type::Manager mgr;
+ core::type::CloneContext ctx{{nullptr}, {nullptr, &mgr}};
+
+ auto* s = a->Clone(ctx);
+ EXPECT_EQ(SubgroupMatrix::Kind::kResult, s->Kind());
+ EXPECT_TRUE(s->Type()->Is<I32>());
+ EXPECT_EQ(s->Rows(), 3u);
+ EXPECT_EQ(s->Columns(), 4u);
+}
+
+} // namespace
+} // namespace tint::core::type