[hlsl-writer] Use `mul` method where required.

This CL updates the binary operator emission to use the `mul()` method
in the following cases:
 - vector * matrix
 - matrix * vector
 - matrix * matrix

This is because the `*` operator works per-component in HLSL which does
not do the expected multiply.

Bug: tint:301
Change-Id: I0810522ac26fbbea323cf8a05a3ff6f2fb62117e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33362
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 2fa9d83..039217d 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -365,6 +365,27 @@
     return true;
   }
 
+  auto* lhs_type = expr->lhs()->result_type()->UnwrapAll();
+  auto* rhs_type = expr->rhs()->result_type()->UnwrapAll();
+  // Multiplying by a matrix requires the use of `mul` in order to get the
+  // type of multiply we desire.
+  if (expr->op() == ast::BinaryOp::kMultiply &&
+      ((lhs_type->IsVector() && rhs_type->IsMatrix()) ||
+       (lhs_type->IsMatrix() && rhs_type->IsVector()) ||
+       (lhs_type->IsMatrix() && rhs_type->IsMatrix()))) {
+    out << "mul(";
+    if (!EmitExpression(pre, out, expr->lhs())) {
+      return false;
+    }
+    out << ", ";
+    if (!EmitExpression(pre, out, expr->rhs())) {
+      return false;
+    }
+    out << ")";
+
+    return true;
+  }
+
   out << "(";
   if (!EmitExpression(pre, out, expr->lhs())) {
     return false;
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 23d76c5..baf428d 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -20,6 +20,7 @@
 #include "src/ast/call_expression.h"
 #include "src/ast/call_statement.h"
 #include "src/ast/else_statement.h"
+#include "src/ast/float_literal.h"
 #include "src/ast/function.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/if_statement.h"
@@ -28,8 +29,13 @@
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
 #include "src/ast/type/bool_type.h"
+#include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
+#include "src/ast/type/u32_type.h"
+#include "src/ast/type/vector_type.h"
 #include "src/ast/type/void_type.h"
+#include "src/ast/type_constructor_expression.h"
 #include "src/ast/variable.h"
 #include "src/ast/variable_decl_statement.h"
 #include "src/writer/hlsl/test_helper.h"
@@ -51,14 +57,69 @@
 }
 
 using HlslBinaryTest = TestParamHelper<BinaryData>;
-TEST_P(HlslBinaryTest, Emit) {
+TEST_P(HlslBinaryTest, Emit_f32) {
+  ast::type::F32Type f32;
+
   auto params = GetParam();
 
+  auto* left_var =
+      create<ast::Variable>("left", ast::StorageClass::kFunction, &f32);
+  auto* right_var =
+      create<ast::Variable>("right", ast::StorageClass::kFunction, &f32);
+
   auto* left = create<ast::IdentifierExpression>("left");
   auto* right = create<ast::IdentifierExpression>("right");
 
+  td.RegisterVariableForTesting(left_var);
+  td.RegisterVariableForTesting(right_var);
+
   ast::BinaryExpression expr(params.op, left, right);
 
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(), params.result);
+}
+TEST_P(HlslBinaryTest, Emit_u32) {
+  ast::type::U32Type u32;
+
+  auto params = GetParam();
+
+  auto* left_var =
+      create<ast::Variable>("left", ast::StorageClass::kFunction, &u32);
+  auto* right_var =
+      create<ast::Variable>("right", ast::StorageClass::kFunction, &u32);
+
+  auto* left = create<ast::IdentifierExpression>("left");
+  auto* right = create<ast::IdentifierExpression>("right");
+
+  td.RegisterVariableForTesting(left_var);
+  td.RegisterVariableForTesting(right_var);
+
+  ast::BinaryExpression expr(params.op, left, right);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(), params.result);
+}
+TEST_P(HlslBinaryTest, Emit_i32) {
+  ast::type::I32Type i32;
+
+  auto params = GetParam();
+
+  auto* left_var =
+      create<ast::Variable>("left", ast::StorageClass::kFunction, &i32);
+  auto* right_var =
+      create<ast::Variable>("right", ast::StorageClass::kFunction, &i32);
+
+  auto* left = create<ast::IdentifierExpression>("left");
+  auto* right = create<ast::IdentifierExpression>("right");
+
+  td.RegisterVariableForTesting(left_var);
+  td.RegisterVariableForTesting(right_var);
+
+  ast::BinaryExpression expr(params.op, left, right);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
   ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
   EXPECT_EQ(result(), params.result);
 }
@@ -83,6 +144,166 @@
         BinaryData{"(left / right)", ast::BinaryOp::kDivide},
         BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
 
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto* lhs = create<ast::TypeConstructorExpression>(
+      &vec3, ast::ExpressionList{
+                 create<ast::ScalarConstructorExpression>(
+                     create<ast::FloatLiteral>(&f32, 1.f)),
+                 create<ast::ScalarConstructorExpression>(
+                     create<ast::FloatLiteral>(&f32, 1.f)),
+                 create<ast::ScalarConstructorExpression>(
+                     create<ast::FloatLiteral>(&f32, 1.f)),
+             });
+
+  auto* rhs = create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f));
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(),
+            "(float3(1.00000000f, 1.00000000f, 1.00000000f) * "
+            "1.00000000f)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto* lhs = create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f));
+
+  ast::ExpressionList vals;
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(),
+            "(1.00000000f * float3(1.00000000f, 1.00000000f, "
+            "1.00000000f))");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+  auto* lhs = create<ast::IdentifierExpression>("mat");
+  auto* rhs = create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f));
+
+  td.RegisterVariableForTesting(var);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(), "(mat * 1.00000000f)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+  auto* lhs = create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f));
+  auto* rhs = create<ast::IdentifierExpression>("mat");
+
+  td.RegisterVariableForTesting(var);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(), "(1.00000000f * mat)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+  auto* lhs = create<ast::IdentifierExpression>("mat");
+
+  ast::ExpressionList vals;
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
+
+  td.RegisterVariableForTesting(var);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(),
+            "mul(mat, float3(1.00000000f, 1.00000000f, 1.00000000f))");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+
+  ast::ExpressionList vals;
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  vals.push_back(create<ast::ScalarConstructorExpression>(
+      create<ast::FloatLiteral>(&f32, 1.f)));
+  auto* lhs = create<ast::TypeConstructorExpression>(&vec3, vals);
+
+  auto* rhs = create<ast::IdentifierExpression>("mat");
+
+  td.RegisterVariableForTesting(var);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(),
+            "mul(float3(1.00000000f, 1.00000000f, 1.00000000f), mat)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3(&f32, 3, 3);
+
+  auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+  auto* lhs = create<ast::IdentifierExpression>("mat");
+  auto* rhs = create<ast::IdentifierExpression>("mat");
+
+  td.RegisterVariableForTesting(var);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+  EXPECT_EQ(result(), "mul(mat, mat)");
+}
+
 TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
   auto* left = create<ast::IdentifierExpression>("left");
   auto* right = create<ast::IdentifierExpression>("right");