[hlsl-writer] Support matrices in storage buffers.
This CL adds the needed code to load matrix data from a storage buffer.
Bug: tint:7
Change-Id: I850b03adc7fa957b7babbad40d07ec3544b0617f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27442
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index bf66a22..ca3a445 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -58,6 +58,7 @@
const char kOutStructNameSuffix[] = "out";
const char kTintStructInVarPrefix[] = "tint_in";
const char kTintStructOutVarPrefix[] = "tint_out";
+const char kTempNamePrefix[] = "_tint_tmp";
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
if (stmts->empty()) {
@@ -1524,7 +1525,6 @@
if (expr->IsMemberAccessor()) {
auto* mem = expr->AsMemberAccessor();
auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias();
-
if (res_type->IsStruct()) {
auto* str_type = res_type->AsStruct()->impl();
auto* str_member = str_type->get_member(mem->member()->name());
@@ -1534,6 +1534,7 @@
return "";
}
out << str_member->offset();
+
} else if (res_type->IsVector()) {
// This must be a single element swizzle if we've got a vector at this
// point.
@@ -1561,7 +1562,6 @@
auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias();
out << "(";
- // TODO(dsinclair): Handle matrix case
if (ary_type->IsArray()) {
out << ary_type->AsArray()->array_stride();
} else if (ary_type->IsVector()) {
@@ -1569,6 +1569,13 @@
// or u32 which are all 4 bytes. When we get f16 or other types we'll
// have to ask the type for the byte size.
out << "4";
+ } else if (ary_type->IsMatrix()) {
+ auto* mat = ary_type->AsMatrix();
+ if (mat->columns() == 2) {
+ out << "8";
+ } else {
+ out << "16";
+ }
} else {
error_ = "Invalid array type in storage buffer access";
return "";
@@ -1600,14 +1607,18 @@
ast::Expression* expr,
ast::Expression* rhs) {
auto* result_type = expr->result_type()->UnwrapAliasPtrAlias();
- std::string access_method = rhs != nullptr ? "Store" : "Load";
+ bool is_store = rhs != nullptr;
+
+ std::string access_method = is_store ? "Store" : "Load";
if (result_type->IsVector()) {
access_method += std::to_string(result_type->AsVector()->size());
+ } else if (result_type->IsMatrix()) {
+ access_method += std::to_string(result_type->AsMatrix()->rows());
}
// If we aren't storing then we need to put in the outer cast.
- if (rhs == nullptr) {
- if (result_type->is_float_scalar_or_vector()) {
+ if (!is_store) {
+ if (result_type->is_float_scalar_or_vector() || result_type->IsMatrix()) {
out << "asfloat(";
} else if (result_type->is_signed_scalar_or_vector()) {
out << "asint(";
@@ -1621,15 +1632,63 @@
error_ = "error emitting storage buffer access";
return false;
}
- out << buffer_name << "." << access_method << "(";
auto idx = generate_storage_buffer_index_expression(expr);
if (idx.empty()) {
return false;
}
- out << idx;
- if (rhs != nullptr) {
+ if (result_type->IsMatrix()) {
+ auto* mat = result_type->AsMatrix();
+
+ // TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed
+ // if we get matrixes of f16 or f64.
+ uint32_t stride = mat->rows() == 2 ? 8 : 16;
+
+ if (is_store) {
+ if (!EmitType(out, mat, "")) {
+ return false;
+ }
+
+ auto name = generate_name(kTempNamePrefix);
+ out << " " << name << " = ";
+ if (!EmitExpression(out, rhs)) {
+ return false;
+ }
+ out << ";" << std::endl;
+
+ for (uint32_t i = 0; i < mat->columns(); i++) {
+ if (i > 0) {
+ out << ";" << std::endl;
+ }
+
+ make_indent(out);
+ out << buffer_name << "." << access_method << "(" << idx << " + "
+ << (i * stride) << ", asuint(" << name << "[" << i << "]))";
+ }
+
+ return true;
+ }
+
+ out << "matrix<uint, " << mat->rows() << ", " << mat->columns() << ">(";
+
+ for (uint32_t i = 0; i < mat->columns(); i++) {
+ if (i != 0) {
+ out << ", ";
+ }
+
+ out << buffer_name << "." << access_method << "(" << idx << " + "
+ << (i * stride) << ")";
+ }
+
+ // Close the matrix type and outer cast
+ out << "))";
+
+ return true;
+ }
+
+ out << buffer_name << "." << access_method << "(" << idx;
+ if (is_store) {
out << ", asuint(";
if (!EmitExpression(out, rhs)) {
return false;
@@ -1640,7 +1699,7 @@
out << ")";
// Close the outer cast.
- if (rhs == nullptr) {
+ if (!is_store) {
out << ")";
}
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc
index 34be490..8eec27a 100644
--- a/src/writer/hlsl/generator_impl_member_accessor_test.cc
+++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -30,6 +30,7 @@
#include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
@@ -174,6 +175,344 @@
ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
EXPECT_EQ(result(), "asint(data.Load(0))");
}
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix) {
+ // struct Data {
+ // [[offset 0]] z : f32;
+ // [[offset 4]] a : mat2x3<f32>;
+ // };
+ // var<storage_buffer> data : Data;
+ // mat2x3<f32> b;
+ // data.a = b;
+ //
+ // -> matrix<float, 3, 2> _tint_tmp = b;
+ // data.Store3(4 + 0, asuint(_tint_tmp[0]));
+ // data.Store3(4 + 16, asuint(_tint_tmp[1]));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::MatrixType mat(&f32, 3, 2);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto b_var =
+ std::make_unique<ast::Variable>("b", ast::StorageClass::kPrivate, &mat);
+
+ auto coord_var = std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s);
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+ auto rhs = std::make_unique<ast::IdentifierExpression>("b");
+
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ td().RegisterVariableForTesting(coord_var.get());
+ td().RegisterVariableForTesting(b_var.get());
+ gen().register_global(coord_var.get());
+ gen().register_global(b_var.get());
+ mod()->AddGlobalVariable(std::move(coord_var));
+ mod()->AddGlobalVariable(std::move(b_var));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(td().DetermineResultType(&assign));
+
+ ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error();
+ EXPECT_EQ(result(), R"(matrix<float, 3, 2> _tint_tmp = b;
+data.Store3(4 + 0, asuint(_tint_tmp[0]));
+data.Store3(4 + 16, asuint(_tint_tmp[1]));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix_Empty) {
+ // struct Data {
+ // [[offset 0]] z : f32;
+ // [[offset 4]] a : mat2x3<f32>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a = mat2x3<f32>();
+ //
+ // -> matrix<float, 3, 2> _tint_tmp = matrix<float, 3, 2>(0.0f, 0.0f, 0.0f,
+ // 0.0f, 0.0f, 0.0f);
+ // data.Store3(4 + 0, asuint(_tint_tmp[0]);
+ // data.Store3(4 + 16, asuint(_tint_tmp[1]));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::MatrixType mat(&f32, 3, 2);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+ auto rhs = std::make_unique<ast::TypeConstructorExpression>(
+ &mat, ast::ExpressionList{});
+
+ ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+ td().RegisterVariableForTesting(coord_var.get());
+ gen().register_global(coord_var.get());
+ mod()->AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(td().DetermineResultType(&assign));
+
+ ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error();
+ EXPECT_EQ(
+ result(),
+ R"(matrix<float, 3, 2> _tint_tmp = matrix<float, 3, 2>(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+data.Store3(4 + 0, asuint(_tint_tmp[0]));
+data.Store3(4 + 16, asuint(_tint_tmp[1]));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix) {
+ // struct Data {
+ // [[offset 0]] z : f32;
+ // [[offset 4]] a : mat3x2<f32>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a;
+ //
+ // -> asfloat(matrix<uint, 2, 3>(data.Load2(4 + 0), data.Load2(4 + 8),
+ // data.Load2(4 + 16)));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::MatrixType mat(&f32, 2, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+
+ td().RegisterVariableForTesting(coord_var.get());
+ gen().register_global(coord_var.get());
+ mod()->AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(td().DetermineResultType(&expr));
+
+ ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+ EXPECT_EQ(result(),
+ "asfloat(matrix<uint, 2, 3>(data.Load2(4 + 0), data.Load2(4 + 8), "
+ "data.Load2(4 + 16)))");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Nested) {
+ // struct Data {
+ // [[offset 0]] z : f32;
+ // [[offset 4]] a : mat2x3<f32;
+ // };
+ // struct Outer {
+ // [[offset 0]] c : f32;
+ // [[offset 4]] b : Data;
+ // };
+ // var<storage_buffer> data : Outer;
+ // data.b.a;
+ //
+ // -> asfloat(matrix<uint, 3, 2>(data.Load3(4 + 0), data.Load3(4 + 16)));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::MatrixType mat(&f32, 3, 2);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+
+ td().RegisterVariableForTesting(coord_var.get());
+ gen().register_global(coord_var.get());
+ mod()->AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(td().DetermineResultType(&expr));
+
+ ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+ EXPECT_EQ(
+ result(),
+ "asfloat(matrix<uint, 3, 2>(data.Load3(4 + 0), data.Load3(4 + 16)))");
+}
+
+TEST_F(
+ HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_By3_Is_16_Bytes) {
+ // struct Data {
+ // [[offset 4]] a : mat3x3<f32;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a;
+ //
+ // -> asfloat(matrix<uint, 3, 3>(data.Load3(0), data.Load3(16),
+ // data.Load3(32)));
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::MatrixType mat(&f32, 3, 3);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList deco;
+ deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &mat, std::move(deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::MemberAccessorExpression expr(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a"));
+
+ td().RegisterVariableForTesting(coord_var.get());
+ gen().register_global(coord_var.get());
+ mod()->AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(td().DetermineResultType(&expr));
+
+ ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+ EXPECT_EQ(result(),
+ "asfloat(matrix<uint, 3, 3>(data.Load3(0 + 0), data.Load3(0 + 16), "
+ "data.Load3(0 + 32)))");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+ EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Single_Element) {
+ // struct Data {
+ // [[offset 0]] z : f32;
+ // [[offset 16]] a : mat4x3<f32>;
+ // };
+ // var<storage_buffer> data : Data;
+ // data.a[2][1];
+ //
+ // -> asfloat(data.Load((2 * 16) + (1 * 4) + 16)))
+ ast::type::F32Type f32;
+ ast::type::I32Type i32;
+ ast::type::MatrixType mat(&f32, 3, 4);
+
+ ast::StructMemberList members;
+ ast::StructMemberDecorationList a_deco;
+ a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+ members.push_back(
+ std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+ ast::StructMemberDecorationList b_deco;
+ b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+ members.push_back(
+ std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+ auto str = std::make_unique<ast::Struct>();
+ str->set_members(std::move(members));
+
+ ast::type::StructType s(std::move(str));
+ s.set_name("Data");
+
+ auto coord_var =
+ std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+ "data", ast::StorageClass::kStorageBuffer, &s));
+
+ ast::ArrayAccessorExpression expr(
+ std::make_unique<ast::ArrayAccessorExpression>(
+ std::make_unique<ast::MemberAccessorExpression>(
+ std::make_unique<ast::IdentifierExpression>("data"),
+ std::make_unique<ast::IdentifierExpression>("a")),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 2))),
+ std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 1)));
+
+ td().RegisterVariableForTesting(coord_var.get());
+ gen().register_global(coord_var.get());
+ mod()->AddGlobalVariable(std::move(coord_var));
+
+ ASSERT_TRUE(td().Determine()) << td().error();
+ ASSERT_TRUE(td().DetermineResultType(&expr));
+
+ ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+ EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + (16 * 2) + 16))");
+}
TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) {