hlsl/writer: Transpose matrices

HLSL's matrices are declared as <type>NxM, where N is the number of
rows and M is the number of columns. Despite HLSL's matrices being
column-major by default, the index operator and constructors actually
operate on row-vectors, where as WGSL operates on column vectors.
To simplify everything we use the transpose of the matrices.

This is the same approach taken by SPIRV-Cross.

Change-Id: I98860e11ff1a68132736980f694b2f68b633ef83
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46873
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/program_builder.cc b/src/program_builder.cc
index 1379730..baf39ca 100644
--- a/src/program_builder.cc
+++ b/src/program_builder.cc
@@ -140,7 +140,8 @@
 }
 
 ast::Function* ProgramBuilder::WrapInFunction(ast::StatementList stmts) {
-  return Func("test_function", {}, ty.void_(), std::move(stmts), {});
+  return Func("test_function", {}, ty.void_(), std::move(stmts),
+              {create<ast::StageDecoration>(ast::PipelineStage::kCompute)});
 }
 
 }  // namespace tint
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 32c7e81..affba86 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -370,12 +370,13 @@
       ((lhs_type->Is<type::Vector>() && rhs_type->Is<type::Matrix>()) ||
        (lhs_type->Is<type::Matrix>() && rhs_type->Is<type::Vector>()) ||
        (lhs_type->Is<type::Matrix>() && rhs_type->Is<type::Matrix>()))) {
+    // Matrices are transposed, so swap LHS and RHS.
     out << "mul(";
-    if (!EmitExpression(pre, out, expr->lhs())) {
+    if (!EmitExpression(pre, out, expr->rhs())) {
       return false;
     }
     out << ", ";
-    if (!EmitExpression(pre, out, expr->rhs())) {
+    if (!EmitExpression(pre, out, expr->lhs())) {
       return false;
     }
     out << ")";
@@ -2529,7 +2530,14 @@
     if (!EmitType(out, mat->type(), "")) {
       return false;
     }
-    out << mat->rows() << "x" << mat->columns();
+    // Note: HLSL's matrices are declared as <type>NxM, where N is the number of
+    // rows and M is the number of columns. Despite HLSL's matrices being
+    // column-major by default, the index operator and constructors actually
+    // operate on row-vectors, where as WGSL operates on column vectors.
+    // To simplify everything we use the transpose of the matrices.
+    // See:
+    // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
+    out << mat->columns() << "x" << mat->rows();
   } else if (type->Is<type::Pointer>()) {
     // TODO(dsinclair): What do we do with pointers in HLSL?
     // https://bugs.chromium.org/p/tint/issues/detail?id=183
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index deb76a0..1c56bfc 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -198,7 +198,7 @@
   GeneratorImpl& gen = Build();
 
   EXPECT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
-  EXPECT_EQ(result(), "mul(mat, float3(1.0f, 1.0f, 1.0f))");
+  EXPECT_EQ(result(), "mul(float3(1.0f, 1.0f, 1.0f), mat)");
 }
 
 TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
@@ -213,7 +213,7 @@
   GeneratorImpl& gen = Build();
 
   EXPECT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
-  EXPECT_EQ(result(), "mul(float3(1.0f, 1.0f, 1.0f), mat)");
+  EXPECT_EQ(result(), "mul(mat, float3(1.0f, 1.0f, 1.0f))");
 }
 
 TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
diff --git a/src/writer/hlsl/generator_impl_constructor_test.cc b/src/writer/hlsl/generator_impl_constructor_test.cc
index 48352be..5959354 100644
--- a/src/writer/hlsl/generator_impl_constructor_test.cc
+++ b/src/writer/hlsl/generator_impl_constructor_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "gmock/gmock.h"
 #include "src/writer/hlsl/test_helper.h"
 
 namespace tint {
@@ -19,6 +20,8 @@
 namespace hlsl {
 namespace {
 
+using ::testing::HasSubstr;
+
 using HlslGeneratorImplTest_Constructor = TestHelper;
 
 TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Bool) {
@@ -113,19 +116,19 @@
 }
 
 TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat) {
-  // WGSL matrix is mat2x3 (it flips for AST, sigh). With a type constructor
-  // of <vec3, vec3>
-
-  auto* expr = mat2x3<f32>(vec3<f32>(1.f, 2.f, 3.f), vec3<f32>(3.f, 4.f, 5.f));
+  WrapInFunction(
+      mat2x3<f32>(vec3<f32>(1.f, 2.f, 3.f), vec3<f32>(3.f, 4.f, 5.f)));
 
   GeneratorImpl& gen = Build();
 
-  ASSERT_TRUE(gen.EmitConstructor(pre, out, expr)) << gen.error();
+  ASSERT_TRUE(gen.Generate(out)) << gen.error();
 
-  // A matrix of type T with n columns and m rows can also be constructed from
-  // n vectors of type T with m components.
-  EXPECT_EQ(result(),
-            "float3x2(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))");
+  EXPECT_THAT(
+      result(),
+      HasSubstr(
+          "float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))"));
+
+  Validate();
 }
 
 TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Array) {
diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc
index c8627d1..23e3fc2 100644
--- a/src/writer/hlsl/generator_impl_member_accessor_test.cc
+++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc
@@ -104,7 +104,7 @@
   // mat2x3<f32> b;
   // data.a = b;
   //
-  // -> float3x2 _tint_tmp = b;
+  // -> float2x3 _tint_tmp = b;
   //    data.Store3(4 + 0, asuint(_tint_tmp[0]));
   //    data.Store3(4 + 16, asuint(_tint_tmp[1]));
 
@@ -126,7 +126,7 @@
   gen.register_global(b_var);
 
   ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error();
-  EXPECT_EQ(result(), R"(float3x2 _tint_tmp = b;
+  EXPECT_EQ(result(), R"(float2x3 _tint_tmp = b;
 data.Store3(16 + 0, asuint(_tint_tmp[0]));
 data.Store3(16 + 16, asuint(_tint_tmp[1]));
 )");
@@ -141,7 +141,7 @@
   // var<storage> data : Data;
   // data.a = mat2x3<f32>();
   //
-  // -> float3x2 _tint_tmp = float3x2(0.0f, 0.0f, 0.0f,
+  // -> float2x3 _tint_tmp = float2x3(0.0f, 0.0f, 0.0f,
   // 0.0f, 0.0f, 0.0f);
   //    data.Store3(16 + 0, asuint(_tint_tmp[0]);
   //    data.Store3(16 + 16, asuint(_tint_tmp[1]));
@@ -164,7 +164,7 @@
   ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error();
   EXPECT_EQ(
       result(),
-      R"(float3x2 _tint_tmp = float3x2(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+      R"(float2x3 _tint_tmp = float2x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
 data.Store3(16 + 0, asuint(_tint_tmp[0]));
 data.Store3(16 + 16, asuint(_tint_tmp[1]));
 )");
diff --git a/src/writer/hlsl/generator_impl_type_test.cc b/src/writer/hlsl/generator_impl_type_test.cc
index 970260a..1b116fa 100644
--- a/src/writer/hlsl/generator_impl_type_test.cc
+++ b/src/writer/hlsl/generator_impl_type_test.cc
@@ -158,7 +158,7 @@
   GeneratorImpl& gen = Build();
 
   ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error();
-  EXPECT_EQ(result(), "float3x2");
+  EXPECT_EQ(result(), "float2x3");
 }
 
 // TODO(dsinclair): How to annotate as workgroup?
diff --git a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc
index b5aacb3..5dae91b 100644
--- a/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc
+++ b/src/writer/hlsl/generator_impl_variable_decl_statement_test.cc
@@ -131,7 +131,7 @@
 
   ASSERT_TRUE(gen.EmitStatement(out, stmt)) << gen.error();
   EXPECT_EQ(result(),
-            R"(float3x2 a = float3x2(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
+            R"(float2x3 a = float2x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
 )");
 }
 
diff --git a/src/writer/wgsl/generator_impl_global_decl_test.cc b/src/writer/wgsl/generator_impl_global_decl_test.cc
index 19f2bdb..ddfd70d 100644
--- a/src/writer/wgsl/generator_impl_global_decl_test.cc
+++ b/src/writer/wgsl/generator_impl_global_decl_test.cc
@@ -35,7 +35,8 @@
   gen.increment_indent();
 
   ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
-  EXPECT_EQ(gen.result(), R"(  fn test_function() -> void {
+  EXPECT_EQ(gen.result(), R"(  [[stage(compute)]]
+  fn test_function() -> void {
     var a : f32;
   }