[glsl][ir] Emit Matrix types
This CL adds emission of matrix types to the GLSL IR backend.
Bug: 42251044
Change-Id: I5ab3ecefd66c6e7ab74efb492ab3b8b4ea6555cb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/204237
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/tint/lang/glsl/writer/printer/printer.cc b/src/tint/lang/glsl/writer/printer/printer.cc
index 958ca60..a499f72 100644
--- a/src/tint/lang/glsl/writer/printer/printer.cc
+++ b/src/tint/lang/glsl/writer/printer/printer.cc
@@ -54,6 +54,7 @@
#include "src/tint/lang/core/type/f16.h"
#include "src/tint/lang/core/type/f32.h"
#include "src/tint/lang/core/type/i32.h"
+#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/pointer.h"
#include "src/tint/lang/core/type/u32.h"
#include "src/tint/lang/core/type/vector.h"
@@ -305,6 +306,7 @@
EmitType(out, p->StoreType(), name, name_printed);
},
[&](const core::type::Vector* v) { EmitVectorType(out, v); },
+ [&](const core::type::Matrix* m) { EmitMatrixType(out, m); },
// TODO(dsinclair): Handle remaining types
TINT_ICE_ON_NO_MATCH);
@@ -326,6 +328,17 @@
out << "vec" << v->Width();
}
+ void EmitMatrixType(StringStream& out, const core::type::Matrix* m) {
+ if (m->Type()->Is<core::type::F16>()) {
+ EmitExtension(kAMDGpuShaderHalfFloat);
+ out << "f16";
+ }
+ out << "mat" << m->Columns();
+ if (m->Rows() != m->Columns()) {
+ out << "x" << m->Rows();
+ }
+ }
+
void EmitArrayType(StringStream& out,
const core::type::Array* ary,
const std::string& name,
@@ -436,6 +449,7 @@
[&](const core::type::F32*) { PrintF32(out, c->ValueAs<f32>()); },
[&](const core::type::F16*) { PrintF16(out, c->ValueAs<f16>()); },
[&](const core::type::Vector* v) { EmitConstantVector(out, v, c); },
+ [&](const core::type::Matrix* m) { EmitConstantMatrix(out, m, c); },
// TODO(dsinclair): Emit remaining constant types
TINT_ICE_ON_NO_MATCH);
@@ -461,6 +475,20 @@
}
}
+ void EmitConstantMatrix(StringStream& out,
+ const core::type::Matrix* m,
+ const core::constant::Value* c) {
+ EmitType(out, m);
+ ScopedParen sp(out);
+
+ for (size_t col_idx = 0; col_idx < m->Columns(); ++col_idx) {
+ if (col_idx > 0) {
+ out << ", ";
+ }
+ EmitConstant(out, c->Index(col_idx));
+ }
+ }
+
void EmitConstantArray(StringStream& out,
const core::type::Array* ary,
const core::constant::Value* c) {
diff --git a/src/tint/lang/glsl/writer/type_test.cc b/src/tint/lang/glsl/writer/type_test.cc
index ae4f717..14a21a7 100644
--- a/src/tint/lang/glsl/writer/type_test.cc
+++ b/src/tint/lang/glsl/writer/type_test.cc
@@ -165,8 +165,7 @@
)");
}
-// TODO(dsinclair): Add matrix support
-TEST_F(GlslWriterTest, DISABLED_EmitType_Matrix_F32) {
+TEST_F(GlslWriterTest, EmitType_Matrix_F32) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
func->SetWorkgroupSize(1, 1, 1);
b.Append(func->Block(), [&] {
@@ -178,13 +177,29 @@
EXPECT_EQ(output_.glsl, GlslHeader() + R"(
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void foo() {
- mat2x3 a = mat2x3(0.0f);
+ mat2x3 a = mat2x3(vec3(0.0f), vec3(0.0f));
}
)");
}
-// TODO(dsinclair): Add matrix support
-TEST_F(GlslWriterTest, DISABLED_EmitType_Matrix_F16) {
+TEST_F(GlslWriterTest, EmitType_MatrixSquare_F32) {
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
+ func->SetWorkgroupSize(1, 1, 1);
+ b.Append(func->Block(), [&] {
+ b.Var("a", ty.ptr(core::AddressSpace::kPrivate, ty.mat2x2<f32>()));
+ b.Return(func);
+ });
+
+ ASSERT_TRUE(Generate()) << err_ << output_.glsl;
+ EXPECT_EQ(output_.glsl, GlslHeader() + R"(
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+void foo() {
+ mat2 a = mat2(vec2(0.0f), vec2(0.0f));
+}
+)");
+}
+
+TEST_F(GlslWriterTest, EmitType_Matrix_F16) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute);
func->SetWorkgroupSize(1, 1, 1);
b.Append(func->Block(), [&] {
@@ -193,10 +208,11 @@
});
ASSERT_TRUE(Generate()) << err_ << output_.glsl;
- EXPECT_EQ(output_.glsl, GlslHeader() + R"(
+ EXPECT_EQ(output_.glsl, GlslHeader() + R"(#extension GL_AMD_gpu_shader_half_float: require
+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void foo() {
- f16mat2x3 a = f16mat2x3(0.0h);
+ f16mat2x3 a = f16mat2x3(f16vec3(0.0hf), f16vec3(0.0hf));
}
)");
}