Add relational expression type determination.
This CL adds the type determination for each of the relation types in
the relational expression.
Bug: tint:5
Change-Id: I15e8dae2f90cc4a0f720692f5addb944b26811ec
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18847
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 9a81e19..762f885 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -30,10 +30,12 @@
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h"
+#include "src/ast/relational_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
+#include "src/ast/type/bool_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
@@ -206,6 +208,9 @@
if (expr->IsMemberAccessor()) {
return DetermineMemberAccessor(expr->AsMemberAccessor());
}
+ if (expr->IsRelational()) {
+ return DetermineRelational(expr->AsRelational());
+ }
error_ = "unknown expression for type determination";
return false;
@@ -321,4 +326,79 @@
return false;
}
+bool TypeDeterminer::DetermineRelational(ast::RelationalExpression* expr) {
+ if (!DetermineResultType(expr->lhs()) || !DetermineResultType(expr->rhs())) {
+ return false;
+ }
+
+ // Result type matches first parameter type
+ if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
+ expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
+ expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
+ expr->set_result_type(expr->lhs()->result_type());
+ return true;
+ }
+ // Result type is a scalar or vector of boolean type
+ if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
+ expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
+ expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
+ auto bool_type =
+ ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+ auto param_type = expr->lhs()->result_type();
+ if (param_type->IsVector()) {
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+ bool_type, param_type->AsVector()->size())));
+ } else {
+ expr->set_result_type(bool_type);
+ }
+ return true;
+ }
+ if (expr->IsMultiply()) {
+ auto lhs_type = expr->lhs()->result_type();
+ auto rhs_type = expr->rhs()->result_type();
+
+ // Note, the ordering here matters. The later checks depend on the prior
+ // checks having been done.
+ if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+ lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
+ rhs_type->AsMatrix()->columns())));
+
+ } else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
+ auto mat = lhs_type->AsMatrix();
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+ mat->type(), mat->rows())));
+ } else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
+ auto mat = rhs_type->AsMatrix();
+ expr->set_result_type(
+ ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+ mat->type(), mat->columns())));
+ } else if (lhs_type->IsMatrix()) {
+ // matrix * scalar
+ expr->set_result_type(lhs_type);
+ } else if (rhs_type->IsMatrix()) {
+ // scalar * matrix
+ expr->set_result_type(rhs_type);
+ } else if (lhs_type->IsVector() && rhs_type->IsVector()) {
+ expr->set_result_type(lhs_type);
+ } else if (lhs_type->IsVector()) {
+ // Vector * scalar
+ expr->set_result_type(lhs_type);
+ } else if (rhs_type->IsVector()) {
+ // Scalar * vector
+ expr->set_result_type(rhs_type);
+ } else {
+ // Scalar * Scalar
+ expr->set_result_type(lhs_type);
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
} // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 34fcaf3..381a273 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -30,9 +30,10 @@
class CallExpression;
class CastExpression;
class ConstructorExpression;
+class Function;
class IdentifierExpression;
class MemberAccessorExpression;
-class Function;
+class RelationalExpression;
class Variable;
} // namespace ast
@@ -81,6 +82,7 @@
bool DetermineConstructor(ast::ConstructorExpression* expr);
bool DetermineIdentifier(ast::IdentifierExpression* expr);
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
+ bool DetermineRelational(ast::RelationalExpression* expr);
Context& ctx_;
std::string error_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 486246b..272e982 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -35,12 +35,14 @@
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h"
+#include "src/ast/relational_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/struct.h"
#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
+#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
@@ -755,5 +757,479 @@
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
}
+using Expr_Relational_BitwiseTest = testing::TestWithParam<ast::Relation>;
+TEST_P(Expr_Relational_BitwiseTest, Scalar) {
+ auto op = GetParam();
+
+ ast::type::I32Type i32;
+
+ auto var =
+ std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ op, std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ EXPECT_TRUE(expr.result_type()->IsI32());
+}
+
+TEST_P(Expr_Relational_BitwiseTest, Vector) {
+ auto op = GetParam();
+
+ ast::type::I32Type i32;
+ ast::type::VectorType vec3(&i32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ op, std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsI32());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+ Expr_Relational_BitwiseTest,
+ testing::Values(ast::Relation::kAnd,
+ ast::Relation::kOr,
+ ast::Relation::kXor,
+ ast::Relation::kShiftLeft,
+ ast::Relation::kShiftRight,
+ ast::Relation::kShiftRightArith,
+ ast::Relation::kAdd,
+ ast::Relation::kSubtract,
+ ast::Relation::kDivide,
+ ast::Relation::kModulo));
+
+using Expr_Relational_LogicalTest = testing::TestWithParam<ast::Relation>;
+TEST_P(Expr_Relational_LogicalTest, Scalar) {
+ auto op = GetParam();
+
+ ast::type::BoolType bool_type;
+
+ auto var = std::make_unique<ast::Variable>("val", ast::StorageClass::kNone,
+ &bool_type);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ op, std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ EXPECT_TRUE(expr.result_type()->IsBool());
+}
+
+TEST_P(Expr_Relational_LogicalTest, Vector) {
+ auto op = GetParam();
+
+ ast::type::BoolType bool_type;
+ ast::type::VectorType vec3(&bool_type, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ op, std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+ Expr_Relational_LogicalTest,
+ testing::Values(ast::Relation::kLogicalAnd,
+ ast::Relation::kLogicalOr));
+
+using Expr_Relational_CompareTest = testing::TestWithParam<ast::Relation>;
+TEST_P(Expr_Relational_CompareTest, Scalar) {
+ auto op = GetParam();
+
+ ast::type::I32Type i32;
+
+ auto var =
+ std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ op, std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ EXPECT_TRUE(expr.result_type()->IsBool());
+}
+
+TEST_P(Expr_Relational_CompareTest, Vector) {
+ auto op = GetParam();
+
+ ast::type::I32Type i32;
+ ast::type::VectorType vec3(&i32, 3);
+
+ auto var =
+ std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ op, std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+ Expr_Relational_CompareTest,
+ testing::Values(ast::Relation::kEqual,
+ ast::Relation::kNotEqual,
+ ast::Relation::kLessThan,
+ ast::Relation::kGreaterThan,
+ ast::Relation::kLessThanEqual,
+ ast::Relation::kGreaterThanEqual));
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Scalar) {
+ ast::type::I32Type i32;
+
+ auto var =
+ std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(var));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("val"),
+ std::make_unique<ast::IdentifierExpression>("val"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ EXPECT_TRUE(expr.result_type()->IsI32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Scalar) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto scalar =
+ std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+ auto vector = std::make_unique<ast::Variable>(
+ "vector", ast::StorageClass::kNone, &vec3);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(scalar));
+ m.AddGlobalVariable(std::move(vector));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("vector"),
+ std::make_unique<ast::IdentifierExpression>("scalar"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Vector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto scalar =
+ std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+ auto vector = std::make_unique<ast::Variable>(
+ "vector", ast::StorageClass::kNone, &vec3);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(scalar));
+ m.AddGlobalVariable(std::move(vector));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("scalar"),
+ std::make_unique<ast::IdentifierExpression>("vector"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Vector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+
+ auto vector = std::make_unique<ast::Variable>(
+ "vector", ast::StorageClass::kNone, &vec3);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(vector));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("vector"),
+ std::make_unique<ast::IdentifierExpression>("vector"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Scalar) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+ auto scalar =
+ std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+ auto matrix = std::make_unique<ast::Variable>(
+ "matrix", ast::StorageClass::kNone, &mat3x2);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(scalar));
+ m.AddGlobalVariable(std::move(matrix));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("matrix"),
+ std::make_unique<ast::IdentifierExpression>("scalar"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsMatrix());
+
+ auto mat = expr.result_type()->AsMatrix();
+ EXPECT_TRUE(mat->type()->IsF32());
+ EXPECT_EQ(mat->rows(), 3);
+ EXPECT_EQ(mat->columns(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Matrix) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+ auto scalar =
+ std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+ auto matrix = std::make_unique<ast::Variable>(
+ "matrix", ast::StorageClass::kNone, &mat3x2);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(scalar));
+ m.AddGlobalVariable(std::move(matrix));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("scalar"),
+ std::make_unique<ast::IdentifierExpression>("matrix"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsMatrix());
+
+ auto mat = expr.result_type()->AsMatrix();
+ EXPECT_TRUE(mat->type()->IsF32());
+ EXPECT_EQ(mat->rows(), 3);
+ EXPECT_EQ(mat->columns(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Vector) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 2);
+ ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+ auto vector = std::make_unique<ast::Variable>(
+ "vector", ast::StorageClass::kNone, &vec3);
+ auto matrix = std::make_unique<ast::Variable>(
+ "matrix", ast::StorageClass::kNone, &mat3x2);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(vector));
+ m.AddGlobalVariable(std::move(matrix));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("matrix"),
+ std::make_unique<ast::IdentifierExpression>("vector"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Matrix) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec3(&f32, 3);
+ ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+ auto vector = std::make_unique<ast::Variable>(
+ "vector", ast::StorageClass::kNone, &vec3);
+ auto matrix = std::make_unique<ast::Variable>(
+ "matrix", ast::StorageClass::kNone, &mat3x2);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(vector));
+ m.AddGlobalVariable(std::move(matrix));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("vector"),
+ std::make_unique<ast::IdentifierExpression>("matrix"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsVector());
+ EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+ EXPECT_EQ(expr.result_type()->AsVector()->size(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Matrix) {
+ ast::type::F32Type f32;
+ ast::type::MatrixType mat4x3(&f32, 4, 3);
+ ast::type::MatrixType mat3x4(&f32, 3, 4);
+
+ auto matrix1 = std::make_unique<ast::Variable>(
+ "mat4x3", ast::StorageClass::kNone, &mat4x3);
+ auto matrix2 = std::make_unique<ast::Variable>(
+ "mat3x4", ast::StorageClass::kNone, &mat3x4);
+
+ Context ctx;
+ TypeDeterminer td(&ctx);
+
+ ast::Module m;
+ m.AddGlobalVariable(std::move(matrix1));
+ m.AddGlobalVariable(std::move(matrix2));
+
+ // Register the global
+ ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+ ast::RelationalExpression expr(
+ ast::Relation::kMultiply,
+ std::make_unique<ast::IdentifierExpression>("mat4x3"),
+ std::make_unique<ast::IdentifierExpression>("mat3x4"));
+
+ ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+ ASSERT_NE(expr.result_type(), nullptr);
+ ASSERT_TRUE(expr.result_type()->IsMatrix());
+
+ auto mat = expr.result_type()->AsMatrix();
+ EXPECT_TRUE(mat->type()->IsF32());
+ EXPECT_EQ(mat->rows(), 4);
+ EXPECT_EQ(mat->columns(), 4);
+}
+
} // namespace
} // namespace tint