[validation] Checks if recursions exist

This CL validates the following rule. ie. As functions must be defined before use (v-0005), self-recursion is only case that has to be invalidated.
v-0004: Recursions are not allowed.

Bug: tint: 6
Change-Id: Icfb040907c5ea0abb6359dade74dcfc30a0db7d9
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26980
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/validator_function_test.cc b/src/validator_function_test.cc
index e23216a..fb484c7 100644
--- a/src/validator_function_test.cc
+++ b/src/validator_function_test.cc
@@ -16,6 +16,7 @@
 
 #include "gtest/gtest.h"
 #include "spirv/unified1/GLSL.std.450.h"
+#include "src/ast/call_statement.h"
 #include "src/ast/return_statement.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
@@ -188,5 +189,56 @@
   EXPECT_EQ(v.error(), "12:34: v-0016: function names must be unique 'func'");
 }
 
+TEST_F(ValidateFunctionTest, RecursionIsNotAllowed_Fail) {
+  // fn func() -> void {func(); return; }
+  ast::type::F32Type f32;
+  ast::type::VoidType void_type;
+  ast::ExpressionList call_params;
+  auto call_expr = std::make_unique<ast::CallExpression>(
+      Source{12, 34}, std::make_unique<ast::IdentifierExpression>("func"),
+      std::move(call_params));
+  ast::VariableList params0;
+  auto func0 =
+      std::make_unique<ast::Function>("func", std::move(params0), &f32);
+  auto body0 = std::make_unique<ast::BlockStatement>();
+  body0->append(std::make_unique<ast::CallStatement>(std::move(call_expr)));
+  body0->append(std::make_unique<ast::ReturnStatement>());
+  func0->set_body(std::move(body0));
+  mod()->AddFunction(std::move(func0));
+
+  EXPECT_TRUE(td()->Determine()) << td()->error();
+  tint::ValidatorImpl v;
+  EXPECT_FALSE(v.Validate(mod())) << v.error();
+  EXPECT_EQ(v.error(), "12:34: v-0004: recursion is not allowed: 'func'");
+}
+
+TEST_F(ValidateFunctionTest, RecursionIsNotAllowedExpr_Fail) {
+  // fn func() -> i32 {var a: i32 = func(); return 2; }
+  ast::type::I32Type i32;
+  auto var =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &i32);
+  ast::ExpressionList call_params;
+  auto call_expr = std::make_unique<ast::CallExpression>(
+      Source{12, 34}, std::make_unique<ast::IdentifierExpression>("func"),
+      std::move(call_params));
+  var->set_constructor(std::move(call_expr));
+  ast::VariableList params0;
+  auto func0 =
+      std::make_unique<ast::Function>("func", std::move(params0), &i32);
+  auto body0 = std::make_unique<ast::BlockStatement>();
+  body0->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  auto return_expr = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::SintLiteral>(&i32, 2));
+
+  body0->append(std::make_unique<ast::ReturnStatement>(std::move(return_expr)));
+  func0->set_body(std::move(body0));
+  mod()->AddFunction(std::move(func0));
+
+  EXPECT_TRUE(td()->Determine()) << td()->error();
+  tint::ValidatorImpl v;
+  EXPECT_FALSE(v.Validate(mod())) << v.error();
+  EXPECT_EQ(v.error(), "12:34: v-0004: recursion is not allowed: 'func'");
+}
+
 }  // namespace
 }  // namespace tint
diff --git a/src/validator_impl.cc b/src/validator_impl.cc
index 672948d..d978cc7 100644
--- a/src/validator_impl.cc
+++ b/src/validator_impl.cc
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "src/validator_impl.h"
+#include "src/ast/call_statement.h"
 #include "src/ast/function.h"
+#include "src/ast/intrinsic.h"
 #include "src/ast/type/void_type.h"
 #include "src/ast/variable_decl_statement.h"
 
@@ -145,7 +147,13 @@
     return false;
   }
   if (stmt->IsVariableDecl()) {
-    return ValidateDeclStatement(stmt->AsVariableDecl());
+    auto* v = stmt->AsVariableDecl();
+    bool constructor_valid =
+        v->variable()->has_constructor()
+            ? ValidateExpression(v->variable()->constructor())
+            : true;
+
+    return constructor_valid && ValidateDeclStatement(stmt->AsVariableDecl());
   }
   if (stmt->IsAssign()) {
     return ValidateAssign(stmt->AsAssign());
@@ -153,6 +161,44 @@
   if (stmt->IsReturn()) {
     return ValidateReturnStatement(stmt->AsReturn());
   }
+  if (stmt->IsCall()) {
+    return ValidateCallExpr(stmt->AsCall()->expr());
+  }
+  return true;
+}
+
+bool ValidatorImpl::ValidateCallExpr(const ast::CallExpression* expr) {
+  if (!expr) {
+    // TODO(sarahM0): Here and other Validate.*: figure out whether return false
+    // or true
+    return false;
+  }
+
+  if (expr->func()->IsIdentifier()) {
+    auto* ident = expr->func()->AsIdentifier();
+    auto func_name = ident->name();
+    if (ident->has_path()) {
+      // TODO(sarahM0): validate import statements
+    } else if (ast::intrinsic::IsIntrinsic(ident->name())) {
+      // TODO(sarahM0): validate intrinsics - tied with type-determiner
+    } else {
+      if (!function_stack_.has(func_name)) {
+        set_error(expr->source(),
+                  "v-0005: function must be declared before use: '" +
+                      func_name + "'");
+        return false;
+      }
+      if (func_name == current_function_->name()) {
+        set_error(expr->source(),
+                  "v-0004: recursion is not allowed: '" + func_name + "'");
+        return false;
+      }
+    }
+  } else {
+    set_error(expr->source(), "Invalid function call expression");
+    return false;
+  }
+
   return true;
 }
 
@@ -218,6 +264,9 @@
     return ValidateIdentifier(expr->AsIdentifier());
   }
 
+  if (expr->IsCall()) {
+    return ValidateCallExpr(expr->AsCall());
+  }
   return true;
 }
 
diff --git a/src/validator_impl.h b/src/validator_impl.h
index 6f0488e..5ed21cf 100644
--- a/src/validator_impl.h
+++ b/src/validator_impl.h
@@ -19,6 +19,7 @@
 #include <unordered_map>
 
 #include "src/ast/assignment_statement.h"
+#include "src/ast/call_expression.h"
 #include "src/ast/expression.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/module.h"
@@ -77,7 +78,7 @@
   bool CheckImports(const ast::Module* module);
   /// Validates an expression
   /// @param expr the expression to check
-  /// @return true if the expresssion is valid
+  /// @return true if the expression is valid
   bool ValidateExpression(const ast::Expression* expr);
   /// Validates v-0006:Variables must be defined before use
   /// @param ident the identifer to check if its in the scope
@@ -100,6 +101,10 @@
   /// @param ret the return statement to check
   /// @returns true if function return type matches the return statement type
   bool ValidateReturnStatement(const ast::ReturnStatement* ret);
+  /// Validates function calls
+  /// @param expr the call to validate
+  /// @returns true if successful
+  bool ValidateCallExpr(const ast::CallExpression* expr);
 
  private:
   std::string error_;
diff --git a/src/validator_test.cc b/src/validator_test.cc
index 853fe6e..1df7170 100644
--- a/src/validator_test.cc
+++ b/src/validator_test.cc
@@ -690,27 +690,32 @@
   EXPECT_TRUE(v.Validate(mod())) << v.error();
 }
 
-TEST_F(ValidatorTest, DISABLED_RecursionIsNotAllowed_Fail) {
-  // fn func() -> void {func(); return; }
-  ast::type::F32Type f32;
-  ast::type::VoidType void_type;
-  ast::ExpressionList call_params;
-  auto call_expr = std::make_unique<ast::CallExpression>(
-      Source{12, 34}, std::make_unique<ast::IdentifierExpression>("func"),
-      std::move(call_params));
-  ast::VariableList params0;
-  auto func0 =
-      std::make_unique<ast::Function>("func", std::move(params0), &f32);
-  auto body0 = std::make_unique<ast::BlockStatement>();
-  body0->append(std::make_unique<ast::CallStatement>(std::move(call_expr)));
-  body0->append(std::make_unique<ast::ReturnStatement>());
-  func0->set_body(std::move(body0));
-  mod()->AddFunction(std::move(func0));
+TEST_F(ValidatorTest, VariableDeclNoConstructor_Pass) {
+  // {
+  // var a :i32;
+  // a = 2;
+  // }
+  ast::type::I32Type i32;
+  auto var =
+      std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &i32);
 
-  EXPECT_TRUE(td()->Determine()) << td()->error();
+  td()->RegisterVariableForTesting(var.get());
+  auto lhs = std::make_unique<ast::IdentifierExpression>("a");
+  auto* lhs_ptr = lhs.get();
+  auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::SintLiteral>(&i32, 2));
+  auto* rhs_ptr = rhs.get();
+
+  auto body = std::make_unique<ast::BlockStatement>();
+  body->append(std::make_unique<ast::VariableDeclStatement>(std::move(var)));
+  body->append(std::make_unique<ast::AssignmentStatement>(
+      Source{12, 34}, std::move(lhs), std::move(rhs)));
+
+  EXPECT_TRUE(td()->DetermineStatements(body.get())) << td()->error();
+  ASSERT_NE(lhs_ptr->result_type(), nullptr);
+  ASSERT_NE(rhs_ptr->result_type(), nullptr);
   tint::ValidatorImpl v;
-  EXPECT_FALSE(v.Validate(mod())) << v.error();
-  EXPECT_EQ(v.error(), "12:34: v-0004: recursion is not allowed: 'func'");
+  EXPECT_TRUE(v.ValidateStatements(body.get())) << v.error();
 }
 
 }  // namespace