[hlsl-writer] Support matrices in storage buffers.

This CL adds the code needed to load and store matrix data in a storage buffer.

Bug: tint:7
Change-Id: I850b03adc7fa957b7babbad40d07ec3544b0617f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27442
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index bf66a22..ca3a445 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -58,6 +58,7 @@
 const char kOutStructNameSuffix[] = "out";
 const char kTintStructInVarPrefix[] = "tint_in";
 const char kTintStructOutVarPrefix[] = "tint_out";
+const char kTempNamePrefix[] = "_tint_tmp";
 
 bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
   if (stmts->empty()) {
@@ -1524,7 +1525,6 @@
     if (expr->IsMemberAccessor()) {
       auto* mem = expr->AsMemberAccessor();
       auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias();
-
       if (res_type->IsStruct()) {
         auto* str_type = res_type->AsStruct()->impl();
         auto* str_member = str_type->get_member(mem->member()->name());
@@ -1534,6 +1534,7 @@
           return "";
         }
         out << str_member->offset();
+
       } else if (res_type->IsVector()) {
         // This must be a single element swizzle if we've got a vector at this
         // point.
@@ -1561,7 +1562,6 @@
       auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias();
 
       out << "(";
-      // TODO(dsinclair): Handle matrix case
       if (ary_type->IsArray()) {
         out << ary_type->AsArray()->array_stride();
       } else if (ary_type->IsVector()) {
@@ -1569,6 +1569,13 @@
         // or u32 which are all 4 bytes. When we get f16 or other types we'll
         // have to ask the type for the byte size.
         out << "4";
+      } else if (ary_type->IsMatrix()) {
+        auto* mat = ary_type->AsMatrix();
+        if (mat->columns() == 2) {
+          out << "8";
+        } else {
+          out << "16";
+        }
       } else {
         error_ = "Invalid array type in storage buffer access";
         return "";
@@ -1600,14 +1607,18 @@
                                               ast::Expression* expr,
                                               ast::Expression* rhs) {
   auto* result_type = expr->result_type()->UnwrapAliasPtrAlias();
-  std::string access_method = rhs != nullptr ? "Store" : "Load";
+  bool is_store = rhs != nullptr;
+
+  std::string access_method = is_store ? "Store" : "Load";
   if (result_type->IsVector()) {
     access_method += std::to_string(result_type->AsVector()->size());
+  } else if (result_type->IsMatrix()) {
+    access_method += std::to_string(result_type->AsMatrix()->rows());
   }
 
   // If we aren't storing then we need to put in the outer cast.
-  if (rhs == nullptr) {
-    if (result_type->is_float_scalar_or_vector()) {
+  if (!is_store) {
+    if (result_type->is_float_scalar_or_vector() || result_type->IsMatrix()) {
       out << "asfloat(";
     } else if (result_type->is_signed_scalar_or_vector()) {
       out << "asint(";
@@ -1621,15 +1632,63 @@
     error_ = "error emitting storage buffer access";
     return false;
   }
-  out << buffer_name << "." << access_method << "(";
 
   auto idx = generate_storage_buffer_index_expression(expr);
   if (idx.empty()) {
     return false;
   }
-  out << idx;
 
-  if (rhs != nullptr) {
+  if (result_type->IsMatrix()) {
+    auto* mat = result_type->AsMatrix();
+
+    // TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed
+    // if we get matrixes of f16 or f64.
+    uint32_t stride = mat->rows() == 2 ? 8 : 16;
+
+    if (is_store) {
+      if (!EmitType(out, mat, "")) {
+        return false;
+      }
+
+      auto name = generate_name(kTempNamePrefix);
+      out << " " << name << " = ";
+      if (!EmitExpression(out, rhs)) {
+        return false;
+      }
+      out << ";" << std::endl;
+
+      for (uint32_t i = 0; i < mat->columns(); i++) {
+        if (i > 0) {
+          out << ";" << std::endl;
+        }
+
+        make_indent(out);
+        out << buffer_name << "." << access_method << "(" << idx << " + "
+            << (i * stride) << ", asuint(" << name << "[" << i << "]))";
+      }
+
+      return true;
+    }
+
+    out << "matrix<uint, " << mat->rows() << ", " << mat->columns() << ">(";
+
+    for (uint32_t i = 0; i < mat->columns(); i++) {
+      if (i != 0) {
+        out << ", ";
+      }
+
+      out << buffer_name << "." << access_method << "(" << idx << " + "
+          << (i * stride) << ")";
+    }
+
+    // Close the matrix type and outer cast
+    out << "))";
+
+    return true;
+  }
+
+  out << buffer_name << "." << access_method << "(" << idx;
+  if (is_store) {
     out << ", asuint(";
     if (!EmitExpression(out, rhs)) {
       return false;
@@ -1640,7 +1699,7 @@
   out << ")";
 
   // Close the outer cast.
-  if (rhs == nullptr) {
+  if (!is_store) {
     out << ")";
   }
 
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc
index 34be490..8eec27a 100644
--- a/src/writer/hlsl/generator_impl_member_accessor_test.cc
+++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -30,6 +30,7 @@
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
@@ -174,6 +175,344 @@
   ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
   EXPECT_EQ(result(), "asint(data.Load(0))");
 }
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+       EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix) {
+  // Storing a matrix to a storage buffer emits one StoreN per column.
+  // struct Data { [[offset 0]] z : i32; [[offset 4]] a : mat2x3<f32>; };
+  // var<storage_buffer> data : Data;
+  // var<private> b : mat2x3<f32>;
+  // data.a = b;
+  //
+  // The rhs is first copied to a temporary; each 3-element column is then
+  // stored 16 bytes apart:
+  // -> matrix<float, 3, 2> _tint_tmp = b;
+  //    data.Store3(4 + 0, asuint(_tint_tmp[0]));
+  //    data.Store3(4 + 16, asuint(_tint_tmp[1]));
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+  ast::type::MatrixType mat(&f32, 3, 2);
+
+  ast::StructMemberList members;
+  ast::StructMemberDecorationList a_deco;
+  a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+  members.push_back(
+      std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+  ast::StructMemberDecorationList b_deco;
+  b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+  members.push_back(
+      std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+  auto str = std::make_unique<ast::Struct>();
+  str->set_members(std::move(members));
+
+  ast::type::StructType s(std::move(str));
+  s.set_name("Data");
+
+  auto b_var =
+      std::make_unique<ast::Variable>("b", ast::StorageClass::kPrivate, &mat);
+
+  auto coord_var = std::make_unique<ast::Variable>(
+      "data", ast::StorageClass::kStorageBuffer, &s);
+
+  auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+      std::make_unique<ast::IdentifierExpression>("data"),
+      std::make_unique<ast::IdentifierExpression>("a"));
+  auto rhs = std::make_unique<ast::IdentifierExpression>("b");
+
+  ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+  td().RegisterVariableForTesting(coord_var.get());
+  td().RegisterVariableForTesting(b_var.get());
+  gen().register_global(coord_var.get());
+  gen().register_global(b_var.get());
+  mod()->AddGlobalVariable(std::move(coord_var));
+  mod()->AddGlobalVariable(std::move(b_var));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(td().DetermineResultType(&assign));
+
+  ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error();
+  EXPECT_EQ(result(), R"(matrix<float, 3, 2> _tint_tmp = b;
+data.Store3(4 + 0, asuint(_tint_tmp[0]));
+data.Store3(4 + 16, asuint(_tint_tmp[1]));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+       EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix_Empty) {
+  // Stores a zero-initialized matrix (empty constructor) via a temporary.
+  // struct Data {
+  //   [[offset 0]] z : i32;
+  //   [[offset 4]] a : mat2x3<f32>;
+  // };
+  // var<storage_buffer> data : Data;
+  // data.a = mat2x3<f32>();
+  //
+  // -> matrix<float, 3, 2> _tint_tmp = matrix<float, 3, 2>(<six 0.0f>);
+  //    data.Store3(4 + 0, asuint(_tint_tmp[0]));
+  //    data.Store3(4 + 16, asuint(_tint_tmp[1]));
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+  ast::type::MatrixType mat(&f32, 3, 2);
+
+  ast::StructMemberList members;
+  ast::StructMemberDecorationList a_deco;
+  a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+  members.push_back(
+      std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+  ast::StructMemberDecorationList b_deco;
+  b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+  members.push_back(
+      std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+  auto str = std::make_unique<ast::Struct>();
+  str->set_members(std::move(members));
+
+  ast::type::StructType s(std::move(str));
+  s.set_name("Data");
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "data", ast::StorageClass::kStorageBuffer, &s));
+
+  auto lhs = std::make_unique<ast::MemberAccessorExpression>(
+      std::make_unique<ast::IdentifierExpression>("data"),
+      std::make_unique<ast::IdentifierExpression>("a"));
+  auto rhs = std::make_unique<ast::TypeConstructorExpression>(
+      &mat, ast::ExpressionList{});
+
+  ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
+
+  td().RegisterVariableForTesting(coord_var.get());
+  gen().register_global(coord_var.get());
+  mod()->AddGlobalVariable(std::move(coord_var));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(td().DetermineResultType(&assign));
+
+  ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error();
+  EXPECT_EQ(
+      result(),
+      R"(matrix<float, 3, 2> _tint_tmp = matrix<float, 3, 2>(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+data.Store3(4 + 0, asuint(_tint_tmp[0]));
+data.Store3(4 + 16, asuint(_tint_tmp[1]));
+)");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+       EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix) {
+  // Loads a matrix as one LoadN per column, reinterpreted with asfloat.
+  // struct Data {
+  //   [[offset 0]] z : i32;
+  //   [[offset 4]] a : mat3x2<f32>;
+  // };
+  // var<storage_buffer> data : Data;
+  // data.a;
+  // -> asfloat(matrix<uint, 2, 3>(data.Load2(4 + 0), data.Load2(4 + 8),
+  //                               data.Load2(4 + 16)))
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+  ast::type::MatrixType mat(&f32, 2, 3);
+
+  ast::StructMemberList members;
+  ast::StructMemberDecorationList a_deco;
+  a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+  members.push_back(
+      std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+  ast::StructMemberDecorationList b_deco;
+  b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+  members.push_back(
+      std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+  auto str = std::make_unique<ast::Struct>();
+  str->set_members(std::move(members));
+
+  ast::type::StructType s(std::move(str));
+  s.set_name("Data");
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "data", ast::StorageClass::kStorageBuffer, &s));
+
+  ast::MemberAccessorExpression expr(
+      std::make_unique<ast::IdentifierExpression>("data"),
+      std::make_unique<ast::IdentifierExpression>("a"));
+
+  td().RegisterVariableForTesting(coord_var.get());
+  gen().register_global(coord_var.get());
+  mod()->AddGlobalVariable(std::move(coord_var));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(td().DetermineResultType(&expr));
+
+  ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+  EXPECT_EQ(result(),
+            "asfloat(matrix<uint, 2, 3>(data.Load2(4 + 0), data.Load2(4 + 8), "
+            "data.Load2(4 + 16)))");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+       EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Nested) {
+  // NOTE(review): the test name says "Nested", but the code below never builds
+  // an Outer struct; it loads `data.a` straight from Data. A true nested case
+  // (`data.b.a` through an Outer { c; b : Data }) remains untested here.
+  //
+  // struct Data {
+  //   [[offset 0]] z : i32;
+  //   [[offset 4]] a : mat2x3<f32>;
+  // };
+  // var<storage_buffer> data : Data;
+  // data.a;
+  //
+  // -> asfloat(matrix<uint, 3, 2>(data.Load3(4 + 0), data.Load3(4 + 16)));
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+  ast::type::MatrixType mat(&f32, 3, 2);
+
+  ast::StructMemberList members;
+  ast::StructMemberDecorationList a_deco;
+  a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+  members.push_back(
+      std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+  ast::StructMemberDecorationList b_deco;
+  b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
+  members.push_back(
+      std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+  auto str = std::make_unique<ast::Struct>();
+  str->set_members(std::move(members));
+
+  ast::type::StructType s(std::move(str));
+  s.set_name("Data");
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "data", ast::StorageClass::kStorageBuffer, &s));
+
+  ast::MemberAccessorExpression expr(
+      std::make_unique<ast::IdentifierExpression>("data"),
+      std::make_unique<ast::IdentifierExpression>("a"));
+
+  td().RegisterVariableForTesting(coord_var.get());
+  gen().register_global(coord_var.get());
+  mod()->AddGlobalVariable(std::move(coord_var));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(td().DetermineResultType(&expr));
+
+  ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+  EXPECT_EQ(
+      result(),
+      "asfloat(matrix<uint, 3, 2>(data.Load3(4 + 0), data.Load3(4 + 16)))");
+}
+
+TEST_F(
+    HlslGeneratorImplTest_MemberAccessor,
+    EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_By3_Is_16_Bytes) {
+  // 3-element columns are loaded 16 bytes apart (not 12).
+  // struct Data {
+  //   [[offset 0]] a : mat3x3<f32>;
+  // };
+  // var<storage_buffer> data : Data;
+  // data.a;
+  // -> asfloat(matrix<uint, 3, 3>(data.Load3(0 + 0), data.Load3(0 + 16),
+  //                               data.Load3(0 + 32)))
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+  ast::type::MatrixType mat(&f32, 3, 3);
+
+  ast::StructMemberList members;
+  ast::StructMemberDecorationList deco;
+  deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+  members.push_back(
+      std::make_unique<ast::StructMember>("a", &mat, std::move(deco)));
+
+  auto str = std::make_unique<ast::Struct>();
+  str->set_members(std::move(members));
+
+  ast::type::StructType s(std::move(str));
+  s.set_name("Data");
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "data", ast::StorageClass::kStorageBuffer, &s));
+
+  ast::MemberAccessorExpression expr(
+      std::make_unique<ast::IdentifierExpression>("data"),
+      std::make_unique<ast::IdentifierExpression>("a"));
+
+  td().RegisterVariableForTesting(coord_var.get());
+  gen().register_global(coord_var.get());
+  mod()->AddGlobalVariable(std::move(coord_var));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(td().DetermineResultType(&expr));
+
+  ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+  EXPECT_EQ(result(),
+            "asfloat(matrix<uint, 3, 3>(data.Load3(0 + 0), data.Load3(0 + 16), "
+            "data.Load3(0 + 32)))");
+}
+
+TEST_F(HlslGeneratorImplTest_MemberAccessor,
+       EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Single_Element) {
+  // Reads one f32: column 2 (16-byte stride), then element 1 (4-byte stride).
+  // struct Data {
+  //   [[offset 0]] z : i32;
+  //   [[offset 16]] a : mat4x3<f32>;
+  // };
+  // var<storage_buffer> data : Data;
+  // data.a[2][1];
+  // -> asfloat(data.Load((4 * 1) + (16 * 2) + 16))
+  ast::type::F32Type f32;
+  ast::type::I32Type i32;
+  ast::type::MatrixType mat(&f32, 3, 4);
+
+  ast::StructMemberList members;
+  ast::StructMemberDecorationList a_deco;
+  a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
+  members.push_back(
+      std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
+
+  ast::StructMemberDecorationList b_deco;
+  b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
+  members.push_back(
+      std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
+
+  auto str = std::make_unique<ast::Struct>();
+  str->set_members(std::move(members));
+
+  ast::type::StructType s(std::move(str));
+  s.set_name("Data");
+
+  auto coord_var =
+      std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
+          "data", ast::StorageClass::kStorageBuffer, &s));
+
+  ast::ArrayAccessorExpression expr(
+      std::make_unique<ast::ArrayAccessorExpression>(
+          std::make_unique<ast::MemberAccessorExpression>(
+              std::make_unique<ast::IdentifierExpression>("data"),
+              std::make_unique<ast::IdentifierExpression>("a")),
+          std::make_unique<ast::ScalarConstructorExpression>(
+              std::make_unique<ast::SintLiteral>(&i32, 2))),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 1)));
+
+  td().RegisterVariableForTesting(coord_var.get());
+  gen().register_global(coord_var.get());
+  mod()->AddGlobalVariable(std::move(coord_var));
+
+  ASSERT_TRUE(td().Determine()) << td().error();
+  ASSERT_TRUE(td().DetermineResultType(&expr));
+
+  ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
+  EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + (16 * 2) + 16))");
+}
 
 TEST_F(HlslGeneratorImplTest_MemberAccessor,
        EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) {