[validation] Validates if return statement type matches function return type

This CL checks if the return statement type matches the function return type

Bug: tint 6
Change-Id: I621d67086291c392b68261673a25c0e6caca71ae
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26860
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/validator.cc b/src/validator.cc
index 5af8414..3186d9e 100644
--- a/src/validator.cc
+++ b/src/validator.cc
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include "src/validator.h"
-
 #include "src/validator_impl.h"
 
 namespace tint {
diff --git a/src/validator_function_test.cc b/src/validator_function_test.cc
index b220e45..e23216a 100644
--- a/src/validator_function_test.cc
+++ b/src/validator_function_test.cc
@@ -88,10 +88,7 @@
             "12:34: v-0002: function must end with a return statement");
 }
 
-TEST_F(ValidateFunctionTest,
-       DISABLED_FunctionTypeMustMatchReturnStatementType_pass) {
-  // TODO(sarahM0): remove DISABLED after implementing function type must match
-  // return type
+TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) {
   // fn func -> void { return; }
   ast::type::VoidType void_type;
   ast::VariableList params;
@@ -107,10 +104,7 @@
   EXPECT_TRUE(v.Validate(mod())) << v.error();
 }
 
-TEST_F(ValidateFunctionTest,
-       DISABLED_FunctionTypeMustMatchReturnStatementType_fail) {
-  // TODO(sarahM0): remove DISABLED after implementing function type must match
-  // return type
+TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) {
   // fn func -> void { return 2; }
   ast::type::VoidType void_type;
   ast::type::I32Type i32;
@@ -130,17 +124,13 @@
   tint::ValidatorImpl v;
   EXPECT_FALSE(v.Validate(mod()));
   // TODO(sarahM0): replace 000y with a rule number
-  EXPECT_EQ(
-      v.error(),
-      "12:34: v-000y: function type must match its return statement type");
+  EXPECT_EQ(v.error(),
+            "12:34: v-000y: return statement type must match its function "
+            "return type, returned '__i32', expected '__void'");
 }
 
-TEST_F(ValidateFunctionTest,
-       DISABLED_FunctionTypeMustMatchReturnStatementTypeF32_fail) {
-  // TODO(sarahM0): remove DISABLED after implementing function type must match
-  // return type
+TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) {
   // fn func -> f32 { return 2; }
-  ast::type::VoidType void_type;
   ast::type::I32Type i32;
   ast::type::F32Type f32;
   ast::VariableList params;
@@ -158,9 +148,9 @@
   tint::ValidatorImpl v;
   EXPECT_FALSE(v.Validate(mod()));
   // TODO(sarahM0): replace 000y with a rule number
-  EXPECT_EQ(
-      v.error(),
-      "12:34: v-000y: function type must match its return statement type");
+  EXPECT_EQ(v.error(),
+            "12:34: v-000y: return statement type must match its function "
+            "return type, returned '__i32', expected '__f32'");
 }
 
 TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) {
diff --git a/src/validator_impl.cc b/src/validator_impl.cc
index 0bbda6c..672948d 100644
--- a/src/validator_impl.cc
+++ b/src/validator_impl.cc
@@ -14,6 +14,7 @@
 
 #include "src/validator_impl.h"
 #include "src/ast/function.h"
+#include "src/ast/type/void_type.h"
 #include "src/ast/variable_decl_statement.h"
 
 namespace tint {
@@ -60,10 +61,11 @@
     }
 
     function_stack_.set(func_ptr->name(), func_ptr);
-
+    current_function_ = func_ptr;
     if (!ValidateFunction(func_ptr)) {
       return false;
     }
+    current_function_ = nullptr;
   }
   return true;
 }
@@ -87,6 +89,28 @@
   return true;
 }
 
+bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) {
+  // TODO(sarahM0): update this when this issue resolves:
+  // https://github.com/gpuweb/gpuweb/issues/996
+  ast::type::Type* func_type = current_function_->return_type();
+
+  ast::type::VoidType void_type;
+  auto* ret_type = ret->has_value()
+                       ? ret->value()->result_type()->UnwrapAliasPtrAlias()
+                       : &void_type;
+
+  if (func_type->type_name() != ret_type->type_name()) {
+    set_error(ret->source(),
+              "v-000y: return statement type must match its function return "
+              "type, returned '" +
+                  ret_type->type_name() + "', expected '" +
+                  func_type->type_name() + "'");
+    return false;
+  }
+
+  return true;
+}
+
 bool ValidatorImpl::ValidateStatements(const ast::BlockStatement* block) {
   if (!block) {
     return false;
@@ -126,6 +150,9 @@
   if (stmt->IsAssign()) {
     return ValidateAssign(stmt->AsAssign());
   }
+  if (stmt->IsReturn()) {
+    return ValidateReturnStatement(stmt->AsReturn());
+  }
   return true;
 }
 
diff --git a/src/validator_impl.h b/src/validator_impl.h
index 8cf7b96..6f0488e 100644
--- a/src/validator_impl.h
+++ b/src/validator_impl.h
@@ -22,8 +22,8 @@
 #include "src/ast/expression.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/module.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/statement.h"
-#include "src/ast/type/type.h"
 #include "src/ast/variable.h"
 #include "src/scope_stack.h"
 
@@ -96,11 +96,16 @@
   /// @returns true if no previous decleration with the |decl|'s name
   /// exist in the variable stack
   bool ValidateDeclStatement(const ast::VariableDeclStatement* decl);
+  /// Validates return statement
+  /// @param ret the return statement to check
+  /// @returns true if function return type matches the return statement type
+  bool ValidateReturnStatement(const ast::ReturnStatement* ret);
 
  private:
   std::string error_;
   ScopeStack<ast::Variable*> variable_stack_;
   ScopeStack<ast::Function*> function_stack_;
+  ast::Function* current_function_ = nullptr;
 };
 
 }  // namespace tint
diff --git a/src/validator_test.cc b/src/validator_test.cc
index 9f4ada0..aaed618 100644
--- a/src/validator_test.cc
+++ b/src/validator_test.cc
@@ -674,7 +674,7 @@
 
   ast::VariableList params1;
   auto func1 =
-      std::make_unique<ast::Function>("func1", std::move(params1), &f32);
+      std::make_unique<ast::Function>("func1", std::move(params1), &void_type);
   auto body1 = std::make_unique<ast::BlockStatement>();
   body1->append(std::make_unique<ast::VariableDeclStatement>(Source{13, 34},
                                                              std::move(var1)));