Add type determination for unary method.

This CL adds the type determination for the unary method expression.

Bug: tint:5
Change-Id: I9f94a79b9715cf74e37c74eb1a612ca84b3c241f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18849
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 8f07dc4..c8d4952 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -36,11 +36,13 @@
 #include "src/ast/switch_statement.h"
 #include "src/ast/type/array_type.h"
 #include "src/ast/type/bool_type.h"
+#include "src/ast/type/f32_type.h"
 #include "src/ast/type/matrix_type.h"
 #include "src/ast/type/struct_type.h"
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/unary_derivative_expression.h"
+#include "src/ast/unary_method_expression.h"
 #include "src/ast/unless_statement.h"
 #include "src/ast/variable_decl_statement.h"
 
@@ -215,6 +217,9 @@
   if (expr->IsUnaryDerivative()) {
     return DetermineUnaryDerivative(expr->AsUnaryDerivative());
   }
+  if (expr->IsUnaryMethod()) {
+    return DetermineUnaryMethod(expr->AsUnaryMethod());
+  }
 
   error_ = "unknown expression for type determination";
   return false;
@@ -415,4 +420,66 @@
   return true;
 }
 
+bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
+  for (const auto& param : expr->params()) {
+    if (!DetermineResultType(param.get())) {
+      return false;
+    }
+  }
+
+  switch (expr->op()) {
+    case ast::UnaryMethod::kAny:
+    case ast::UnaryMethod::kAll: {
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()));
+      break;
+    }
+    case ast::UnaryMethod::kIsNan:
+    case ast::UnaryMethod::kIsInf:
+    case ast::UnaryMethod::kIsFinite:
+    case ast::UnaryMethod::kIsNormal: {
+      if (expr->params().empty()) {
+        error_ = "incorrect number of parameters";
+        return false;
+      }
+
+      auto bool_type =
+          ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
+      auto param_type = expr->params()[0]->result_type();
+      if (param_type->IsVector()) {
+        expr->set_result_type(
+            ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
+                bool_type, param_type->AsVector()->size())));
+      } else {
+        expr->set_result_type(bool_type);
+      }
+      break;
+    }
+    case ast::UnaryMethod::kDot: {
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
+      break;
+    }
+    case ast::UnaryMethod::kOuterProduct: {
+      if (expr->params().size() != 2) {
+        error_ = "incorrect number of parameters for outer product";
+        return false;
+      }
+      auto param0_type = expr->params()[0]->result_type();
+      auto param1_type = expr->params()[1]->result_type();
+      if (!param0_type->IsVector() || !param1_type->IsVector()) {
+        error_ = "invalid parameter type for outer product";
+        return false;
+      }
+      expr->set_result_type(
+          ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
+              ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()),
+              param0_type->AsVector()->size(),
+              param1_type->AsVector()->size())));
+      break;
+    }
+  }
+  return true;
+}
+
 }  // namespace tint
diff --git a/src/type_determiner.h b/src/type_determiner.h
index de6a188..126b755 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -35,6 +35,7 @@
 class MemberAccessorExpression;
 class RelationalExpression;
 class UnaryDerivativeExpression;
+class UnaryMethodExpression;
 class Variable;
 
 }  // namespace ast
@@ -85,6 +86,7 @@
   bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
   bool DetermineRelational(ast::RelationalExpression* expr);
   bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr);
+  bool DetermineUnaryMethod(ast::UnaryMethodExpression* expr);
 
   Context& ctx_;
   std::string error_;
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index a416916..1d63956 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -50,6 +50,7 @@
 #include "src/ast/type/vector_type.h"
 #include "src/ast/type_constructor_expression.h"
 #include "src/ast/unary_derivative_expression.h"
+#include "src/ast/unary_method_expression.h"
 #include "src/ast/unless_statement.h"
 #include "src/ast/variable_decl_statement.h"
 
@@ -1268,5 +1269,164 @@
                                          ast::UnaryDerivative::kDpdy,
                                          ast::UnaryDerivative::kFwidth));
 
+using UnaryMethodExpressionBoolTest = testing::TestWithParam<ast::UnaryMethod>;
+TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) {
+  auto op = GetParam();
+
+  ast::type::BoolType bool_type;
+  ast::type::VectorType vec3(&bool_type, 3);
+
+  auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+                                             &vec3);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  ast::ExpressionList params;
+  params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+  ast::UnaryMethodExpression exp(op, std::move(params));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  // Register the variable
+  EXPECT_TRUE(td.Determine(&m));
+
+  EXPECT_TRUE(td.DetermineResultType(&exp));
+  ASSERT_NE(exp.result_type(), nullptr);
+  EXPECT_TRUE(exp.result_type()->IsBool());
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+                         UnaryMethodExpressionBoolTest,
+                         testing::Values(ast::UnaryMethod::kAny,
+                                         ast::UnaryMethod::kAll));
+
+using UnaryMethodExpressionVecTest = testing::TestWithParam<ast::UnaryMethod>;
+TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) {
+  auto op = GetParam();
+
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+                                             &vec3);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  ast::ExpressionList params;
+  params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+  ast::UnaryMethodExpression exp(op, std::move(params));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  // Register the variable
+  EXPECT_TRUE(td.Determine(&m));
+
+  EXPECT_TRUE(td.DetermineResultType(&exp));
+  ASSERT_NE(exp.result_type(), nullptr);
+  ASSERT_TRUE(exp.result_type()->IsVector());
+  EXPECT_TRUE(exp.result_type()->AsVector()->type()->IsBool());
+  EXPECT_EQ(exp.result_type()->AsVector()->size(), 3);
+}
+TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Vec) {
+  auto op = GetParam();
+
+  ast::type::F32Type f32;
+
+  auto var =
+      std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  ast::ExpressionList params;
+  params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+  ast::UnaryMethodExpression exp(op, std::move(params));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  // Register the variable
+  EXPECT_TRUE(td.Determine(&m));
+
+  EXPECT_TRUE(td.DetermineResultType(&exp));
+  ASSERT_NE(exp.result_type(), nullptr);
+  EXPECT_TRUE(exp.result_type()->IsBool());
+}
+INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
+                         UnaryMethodExpressionVecTest,
+                         testing::Values(ast::UnaryMethod::kIsInf,
+                                         ast::UnaryMethod::kIsNan,
+                                         ast::UnaryMethod::kIsFinite,
+                                         ast::UnaryMethod::kIsNormal));
+
+TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+
+  auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
+                                             &vec3);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var));
+
+  ast::ExpressionList params;
+  params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+  params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
+
+  ast::UnaryMethodExpression exp(ast::UnaryMethod::kDot, std::move(params));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  // Register the variable
+  EXPECT_TRUE(td.Determine(&m));
+
+  EXPECT_TRUE(td.DetermineResultType(&exp));
+  ASSERT_NE(exp.result_type(), nullptr);
+  EXPECT_TRUE(exp.result_type()->IsF32());
+}
+
+TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) {
+  ast::type::F32Type f32;
+  ast::type::VectorType vec3(&f32, 3);
+  ast::type::VectorType vec2(&f32, 2);
+
+  auto var1 =
+      std::make_unique<ast::Variable>("v3", ast::StorageClass::kNone, &vec3);
+  auto var2 =
+      std::make_unique<ast::Variable>("v2", ast::StorageClass::kNone, &vec2);
+
+  ast::Module m;
+  m.AddGlobalVariable(std::move(var1));
+  m.AddGlobalVariable(std::move(var2));
+
+  ast::ExpressionList params;
+  params.push_back(std::make_unique<ast::IdentifierExpression>("v3"));
+  params.push_back(std::make_unique<ast::IdentifierExpression>("v2"));
+
+  ast::UnaryMethodExpression exp(ast::UnaryMethod::kOuterProduct,
+                                 std::move(params));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+
+  // Register the variable
+  EXPECT_TRUE(td.Determine(&m));
+
+  EXPECT_TRUE(td.DetermineResultType(&exp));
+  ASSERT_NE(exp.result_type(), nullptr);
+  ASSERT_TRUE(exp.result_type()->IsMatrix());
+  auto mat = exp.result_type()->AsMatrix();
+  EXPECT_TRUE(mat->type()->IsF32());
+  EXPECT_EQ(mat->rows(), 3);
+  EXPECT_EQ(mat->columns(), 2);
+}
+
 }  // namespace
 }  // namespace tint