[hlsl-writer] Use `mul` method where required.
This CL updates the binary operator emission to use the `mul()` method
in the following cases:
- vector * matrix
- matrix * vector
- matrix * matrix
This is because the `*` operator works per-component in HLSL which does
not do the expected multiply.
Bug: tint:301
Change-Id: I0810522ac26fbbea323cf8a05a3ff6f2fb62117e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33362
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 2fa9d83..039217d 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -365,6 +365,27 @@
return true;
}
+ auto* lhs_type = expr->lhs()->result_type()->UnwrapAll();
+ auto* rhs_type = expr->rhs()->result_type()->UnwrapAll();
+ // Multiplying by a matrix requires the use of `mul` in order to get the
+ // type of multiply we desire.
+ if (expr->op() == ast::BinaryOp::kMultiply &&
+ ((lhs_type->IsVector() && rhs_type->IsMatrix()) ||
+ (lhs_type->IsMatrix() && rhs_type->IsVector()) ||
+ (lhs_type->IsMatrix() && rhs_type->IsMatrix()))) {
+ out << "mul(";
+ if (!EmitExpression(pre, out, expr->lhs())) {
+ return false;
+ }
+ out << ", ";
+ if (!EmitExpression(pre, out, expr->rhs())) {
+ return false;
+ }
+ out << ")";
+
+ return true;
+ }
+
out << "(";
if (!EmitExpression(pre, out, expr->lhs())) {
return false;
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index 23d76c5..baf428d 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -20,6 +20,7 @@
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/else_statement.h"
+#include "src/ast/float_literal.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@@ -28,8 +29,13 @@
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
#include "src/ast/type/bool_type.h"
+#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
+#include "src/ast/type/matrix_type.h"
+#include "src/ast/type/u32_type.h"
+#include "src/ast/type/vector_type.h"
#include "src/ast/type/void_type.h"
+#include "src/ast/type_constructor_expression.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/writer/hlsl/test_helper.h"
@@ -51,14 +57,69 @@
}
using HlslBinaryTest = TestParamHelper<BinaryData>;
-TEST_P(HlslBinaryTest, Emit) {
+TEST_P(HlslBinaryTest, Emit_f32) {
+ ast::type::F32Type f32;
+
auto params = GetParam();
+ auto* left_var =
+ create<ast::Variable>("left", ast::StorageClass::kFunction, &f32);
+ auto* right_var =
+ create<ast::Variable>("right", ast::StorageClass::kFunction, &f32);
+
auto* left = create<ast::IdentifierExpression>("left");
auto* right = create<ast::IdentifierExpression>("right");
+ td.RegisterVariableForTesting(left_var);
+ td.RegisterVariableForTesting(right_var);
+
ast::BinaryExpression expr(params.op, left, right);
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(), params.result);
+}
+TEST_P(HlslBinaryTest, Emit_u32) {
+ ast::type::U32Type u32;
+
+ auto params = GetParam();
+
+ auto* left_var =
+ create<ast::Variable>("left", ast::StorageClass::kFunction, &u32);
+ auto* right_var =
+ create<ast::Variable>("right", ast::StorageClass::kFunction, &u32);
+
+ auto* left = create<ast::IdentifierExpression>("left");
+ auto* right = create<ast::IdentifierExpression>("right");
+
+ td.RegisterVariableForTesting(left_var);
+ td.RegisterVariableForTesting(right_var);
+
+ ast::BinaryExpression expr(params.op, left, right);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(), params.result);
+}
+TEST_P(HlslBinaryTest, Emit_i32) {
+ ast::type::I32Type i32;
+
+ auto params = GetParam();
+
+ auto* left_var =
+ create<ast::Variable>("left", ast::StorageClass::kFunction, &i32);
+ auto* right_var =
+ create<ast::Variable>("right", ast::StorageClass::kFunction, &i32);
+
+ auto* left = create<ast::IdentifierExpression>("left");
+ auto* right = create<ast::IdentifierExpression>("right");
+
+ td.RegisterVariableForTesting(left_var);
+ td.RegisterVariableForTesting(right_var);
+
+ ast::BinaryExpression expr(params.op, left, right);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), params.result);
}
@@ -83,6 +144,166 @@
BinaryData{"(left / right)", ast::BinaryOp::kDivide},
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto* lhs = create<ast::TypeConstructorExpression>(
+ &vec3, ast::ExpressionList{
+ create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)),
+ create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)),
+ create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)),
+ });
+
+ auto* rhs = create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f));
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(),
+ "(float3(1.00000000f, 1.00000000f, 1.00000000f) * "
+ "1.00000000f)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto* lhs = create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f));
+
+ ast::ExpressionList vals;
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(),
+ "(1.00000000f * float3(1.00000000f, 1.00000000f, "
+ "1.00000000f))");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+ auto* lhs = create<ast::IdentifierExpression>("mat");
+ auto* rhs = create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f));
+
+ td.RegisterVariableForTesting(var);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(), "(mat * 1.00000000f)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+ auto* lhs = create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f));
+ auto* rhs = create<ast::IdentifierExpression>("mat");
+
+ td.RegisterVariableForTesting(var);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(), "(1.00000000f * mat)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+ auto* lhs = create<ast::IdentifierExpression>("mat");
+
+ ast::ExpressionList vals;
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
+
+ td.RegisterVariableForTesting(var);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(),
+ "mul(mat, float3(1.00000000f, 1.00000000f, 1.00000000f))");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+
+ ast::ExpressionList vals;
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ vals.push_back(create<ast::ScalarConstructorExpression>(
+ create<ast::FloatLiteral>(&f32, 1.f)));
+ auto* lhs = create<ast::TypeConstructorExpression>(&vec3, vals);
+
+ auto* rhs = create<ast::IdentifierExpression>("mat");
+
+ td.RegisterVariableForTesting(var);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(),
+ "mul(float3(1.00000000f, 1.00000000f, 1.00000000f), mat)");
+}
+
+TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3(&f32, 3, 3);
+
+ auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
+ auto* lhs = create<ast::IdentifierExpression>("mat");
+ auto* rhs = create<ast::IdentifierExpression>("mat");
+
+ td.RegisterVariableForTesting(var);
+
+ ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
+ EXPECT_EQ(result(), "mul(mat, mat)");
+}
+
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
auto* left = create<ast::IdentifierExpression>("left");
auto* right = create<ast::IdentifierExpression>("right");