Add relational expression type determination.

This CL adds the type determination for each of the relation types in
the relational expression.

Bug: tint:5
Change-Id: I15e8dae2f90cc4a0f720692f5addb944b26811ec
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18847
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 9a81e19..762f885 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -30,10 +30,12 @@
 #include "src/ast/loop_statement.h"
 #include "src/ast/member_accessor_expression.h"
 #include "src/ast/regardless_statement.h"
+#include "src/ast/relational_expression.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/array_type.h"
+#include "src/ast/type/bool_type.h"
 #include "src/ast/type/matrix_type.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type/vector_type.h"
@@ -206,6 +208,9 @@
   if (expr->IsMemberAccessor()) {
     return DetermineMemberAccessor(expr->AsMemberAccessor());
   }
+  if (expr->IsRelational()) {
+    return DetermineRelational(expr->AsRelational());
+  }
 
   error_ = "unknown expression for type determination";
   return false;
@@ -321,4 +326,79 @@
   return false;
 }
 
+bool TypeDeterminer::DetermineRelational(ast::RelationalExpression* expr) {
+  if (!DetermineResultType(expr->lhs()) || !DetermineResultType(expr->rhs())) {
+    return false;
+  }
+
+  // Result type matches first parameter type
+  if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
+      expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
+      expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
+    expr->set_result_type(expr->lhs()->result_type());
+    return true;
+  }
+  // Result type is a scalar or vector of boolean type
+  if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
+      expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
+      expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
+    auto bool_type =
+        ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+    auto param_type = expr->lhs()->result_type();
+    if (param_type->IsVector()) {
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+              bool_type, param_type->AsVector()->size())));
+    } else {
+      expr->set_result_type(bool_type);
+    }
+    return true;
+  }
+  if (expr->IsMultiply()) {
+    auto lhs_type = expr->lhs()->result_type();
+    auto rhs_type = expr->rhs()->result_type();
+
+    // Note, the ordering here matters. The later checks depend on the prior
+    // checks having been done.
+    if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+              lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
+              rhs_type->AsMatrix()->columns())));
+
+    } else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
+      auto mat = lhs_type->AsMatrix();
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+              mat->type(), mat->rows())));
+    } else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
+      auto mat = rhs_type->AsMatrix();
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+              mat->type(), mat->columns())));
+    } else if (lhs_type->IsMatrix()) {
+      // matrix * scalar
+      expr->set_result_type(lhs_type);
+    } else if (rhs_type->IsMatrix()) {
+      // scalar * matrix
+      expr->set_result_type(rhs_type);
+    } else if (lhs_type->IsVector() && rhs_type->IsVector()) {
+      expr->set_result_type(lhs_type);
+    } else if (lhs_type->IsVector()) {
+      // Vector * scalar
+      expr->set_result_type(lhs_type);
+    } else if (rhs_type->IsVector()) {
+      // Scalar * vector
+      expr->set_result_type(rhs_type);
+    } else {
+      // Scalar * Scalar
+      expr->set_result_type(lhs_type);
+    }
+
+    return true;
+  }
+
+  return false;
+}
+
 }  // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 34fcaf3..381a273 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -30,9 +30,10 @@
 class CallExpression;
 class CastExpression;
 class ConstructorExpression;
+class Function;
 class IdentifierExpression;
 class MemberAccessorExpression;
-class Function;
+class RelationalExpression;
 class Variable;
 
 }  // namespace ast
@@ -81,6 +82,7 @@
   bool DetermineConstructor(ast::ConstructorExpression* expr);
   bool DetermineIdentifier(ast::IdentifierExpression* expr);
   bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
+  bool DetermineRelational(ast::RelationalExpression* expr);
 
   Context& ctx_;
   std::string error_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 486246b..272e982 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -35,12 +35,14 @@
 #include "src/ast/loop_statement.h"
 #include "src/ast/member_accessor_expression.h"
 #include "src/ast/regardless_statement.h"
+#include "src/ast/relational_expression.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/struct.h"
 #include "src/ast/struct_member.h"
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/array_type.h"
+#include "src/ast/type/bool_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/matrix_type.h"
@@ -755,5 +757,479 @@
   EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
 }
 
+using Expr_Relational_BitwiseTest = testing::TestWithParam<ast::Relation>;
+TEST_P(Expr_Relational_BitwiseTest, Scalar) {
+  auto op = GetParam();
+
+  ast::type::I32Type i32;
+
+  auto var =
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      op, std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  EXPECT_TRUE(expr.result_type()->IsI32());
+}
+
+TEST_P(Expr_Relational_BitwiseTest, Vector) {
+  auto op = GetParam();
+
+  ast::type::I32Type i32;
+  ast::type::VectorType vec3(&i32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      op, std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsI32());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+                         Expr_Relational_BitwiseTest,
+                         testing::Values(ast::Relation::kAnd,
+                                         ast::Relation::kOr,
+                                         ast::Relation::kXor,
+                                         ast::Relation::kShiftLeft,
+                                         ast::Relation::kShiftRight,
+                                         ast::Relation::kShiftRightArith,
+                                         ast::Relation::kAdd,
+                                         ast::Relation::kSubtract,
+                                         ast::Relation::kDivide,
+                                         ast::Relation::kModulo));
+
+using Expr_Relational_LogicalTest = testing::TestWithParam<ast::Relation>;
+TEST_P(Expr_Relational_LogicalTest, Scalar) {
+  auto op = GetParam();
+
+  ast::type::BoolType bool_type;
+
+  auto var = std::make_unique<ast::Variable>("val", ast::StorageClass::kNone,
+                                             &bool_type);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      op, std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  EXPECT_TRUE(expr.result_type()->IsBool());
+}
+
+TEST_P(Expr_Relational_LogicalTest, Vector) {
+  auto op = GetParam();
+
+  ast::type::BoolType bool_type;
+  ast::type::VectorType vec3(&bool_type, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      op, std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+                         Expr_Relational_LogicalTest,
+                         testing::Values(ast::Relation::kLogicalAnd,
+                                         ast::Relation::kLogicalOr));
+
+using Expr_Relational_CompareTest = testing::TestWithParam<ast::Relation>;
+TEST_P(Expr_Relational_CompareTest, Scalar) {
+  auto op = GetParam();
+
+  ast::type::I32Type i32;
+
+  auto var =
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      op, std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  EXPECT_TRUE(expr.result_type()->IsBool());
+}
+
+TEST_P(Expr_Relational_CompareTest, Vector) {
+  auto op = GetParam();
+
+  ast::type::I32Type i32;
+  ast::type::VectorType vec3(&i32, 3);
+
+  auto var =
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      op, std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+                         Expr_Relational_CompareTest,
+                         testing::Values(ast::Relation::kEqual,
+                                         ast::Relation::kNotEqual,
+                                         ast::Relation::kLessThan,
+                                         ast::Relation::kGreaterThan,
+                                         ast::Relation::kLessThanEqual,
+                                         ast::Relation::kGreaterThanEqual));
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Scalar) {
+  ast::type::I32Type i32;
+
+  auto var =
+      std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("val"),
+      std::make_unique<ast::IdentifierExpression>("val"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  EXPECT_TRUE(expr.result_type()->IsI32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Scalar) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto scalar =
+      std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+  auto vector = std::make_unique<ast::Variable>(
+      "vector", ast::StorageClass::kNone, &vec3);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(scalar));
+  m.AddGlobalVariable(std::move(vector));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("vector"),
+      std::make_unique<ast::IdentifierExpression>("scalar"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Vector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto scalar =
+      std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+  auto vector = std::make_unique<ast::Variable>(
+      "vector", ast::StorageClass::kNone, &vec3);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(scalar));
+  m.AddGlobalVariable(std::move(vector));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("scalar"),
+      std::make_unique<ast::IdentifierExpression>("vector"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Vector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto vector = std::make_unique<ast::Variable>(
+      "vector", ast::StorageClass::kNone, &vec3);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(vector));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("vector"),
+      std::make_unique<ast::IdentifierExpression>("vector"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Scalar) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+  auto scalar =
+      std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+  auto matrix = std::make_unique<ast::Variable>(
+      "matrix", ast::StorageClass::kNone, &mat3x2);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(scalar));
+  m.AddGlobalVariable(std::move(matrix));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("matrix"),
+      std::make_unique<ast::IdentifierExpression>("scalar"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsMatrix());
+
+  auto mat = expr.result_type()->AsMatrix();
+  EXPECT_TRUE(mat->type()->IsF32());
+  EXPECT_EQ(mat->rows(), 3);
+  EXPECT_EQ(mat->columns(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Matrix) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+  auto scalar =
+      std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
+  auto matrix = std::make_unique<ast::Variable>(
+      "matrix", ast::StorageClass::kNone, &mat3x2);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(scalar));
+  m.AddGlobalVariable(std::move(matrix));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("scalar"),
+      std::make_unique<ast::IdentifierExpression>("matrix"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsMatrix());
+
+  auto mat = expr.result_type()->AsMatrix();
+  EXPECT_TRUE(mat->type()->IsF32());
+  EXPECT_EQ(mat->rows(), 3);
+  EXPECT_EQ(mat->columns(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Vector) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 2);
+  ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+  auto vector = std::make_unique<ast::Variable>(
+      "vector", ast::StorageClass::kNone, &vec3);
+  auto matrix = std::make_unique<ast::Variable>(
+      "matrix", ast::StorageClass::kNone, &mat3x2);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(vector));
+  m.AddGlobalVariable(std::move(matrix));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("matrix"),
+      std::make_unique<ast::IdentifierExpression>("vector"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Matrix) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::MatrixType mat3x2(&f32, 3, 2);
+
+  auto vector = std::make_unique<ast::Variable>(
+      "vector", ast::StorageClass::kNone, &vec3);
+  auto matrix = std::make_unique<ast::Variable>(
+      "matrix", ast::StorageClass::kNone, &mat3x2);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(vector));
+  m.AddGlobalVariable(std::move(matrix));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("vector"),
+      std::make_unique<ast::IdentifierExpression>("matrix"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsVector());
+  EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
+  EXPECT_EQ(expr.result_type()->AsVector()->size(), 2);
+}
+
+TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Matrix) {
+  ast::type::F32Type f32;
+  ast::type::MatrixType mat4x3(&f32, 4, 3);
+  ast::type::MatrixType mat3x4(&f32, 3, 4);
+
+  auto matrix1 = std::make_unique<ast::Variable>(
+      "mat4x3", ast::StorageClass::kNone, &mat4x3);
+  auto matrix2 = std::make_unique<ast::Variable>(
+      "mat3x4", ast::StorageClass::kNone, &mat3x4);
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(matrix1));
+  m.AddGlobalVariable(std::move(matrix2));
+
+  // Register the global
+  ASSERT_TRUE(td.Determine(&m)) << td.error();
+
+  ast::RelationalExpression expr(
+      ast::Relation::kMultiply,
+      std::make_unique<ast::IdentifierExpression>("mat4x3"),
+      std::make_unique<ast::IdentifierExpression>("mat3x4"));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+  ASSERT_NE(expr.result_type(), nullptr);
+  ASSERT_TRUE(expr.result_type()->IsMatrix());
+
+  auto mat = expr.result_type()->AsMatrix();
+  EXPECT_TRUE(mat->type()->IsF32());
+  EXPECT_EQ(mat->rows(), 4);
+  EXPECT_EQ(mat->columns(), 4);
+}
+
 }  // namespace
 }  // namespace tint