Move return validation from Validator to Resolver

Improved error message to use friendly names. Fixed tests that broke as
a result of this change.

Bug: tint:642
Change-Id: I9a1e819e1a6110a89c826936b96ab84f7f79a084
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45582
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc
index 0da990a..ba839de 100644
--- a/src/resolver/function_validation_test.cc
+++ b/src/resolver/function_validation_test.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "src/ast/return_statement.h"
+#include "src/ast/stage_decoration.h"
 #include "src/resolver/resolver.h"
 #include "src/resolver/resolver_test_helper.h"
 
@@ -74,5 +75,111 @@
       "12:34 error v-0002: non-void function must end with a return statement");
 }
 
+TEST_F(ResolverFunctionValidationTest,
+       FunctionTypeMustMatchReturnStatementType_Pass) {
+  // [[stage(vertex)]]
+  // fn func -> void { return; }
+
+  Func("func", ast::VariableList{}, ty.void_(),
+       ast::StatementList{
+           create<ast::ReturnStatement>(),
+       },
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+       FunctionTypeMustMatchReturnStatementType_fail) {
+  // fn func -> void { return 2; }
+  Func("func", ast::VariableList{}, ty.void_(),
+       ast::StatementList{
+           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
+                                        Expr(2)),
+       },
+       ast::DecorationList{});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error v-000y: return statement type must match its function "
+            "return type, returned 'i32', expected 'void'");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+       FunctionTypeMustMatchReturnStatementTypeF32_pass) {
+  // fn func -> f32 { return 2.0; }
+  Func("func", ast::VariableList{}, ty.f32(),
+       ast::StatementList{
+           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
+                                        Expr(2.f)),
+       },
+       ast::DecorationList{});
+  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+       FunctionTypeMustMatchReturnStatementTypeF32_fail) {
+  // fn func -> f32 { return 2; }
+  Func("func", ast::VariableList{}, ty.f32(),
+       ast::StatementList{
+           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
+                                        Expr(2)),
+       },
+       ast::DecorationList{});
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error v-000y: return statement type must match its function "
+            "return type, returned 'i32', expected 'f32'");
+}
+
+TEST_F(ResolverFunctionValidationTest,
+       FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) {
+  // type myf32 = f32;
+  // fn func -> myf32 { return 2.0; }
+  auto* myf32 = ty.alias("myf32", ty.f32());
+  Func("func", ast::VariableList{}, myf32,
+       ast::StatementList{
+           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
+                                        Expr(2.f)),
+       },
+       ast::DecorationList{});
+  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+}
+
+TEST_F(ResolverFunctionValidationTest,
+       FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) {
+  // type myf32 = f32;
+  // fn func -> myf32 { return 2; }
+  auto* myf32 = ty.alias("myf32", ty.f32());
+  Func("func", ast::VariableList{}, myf32,
+       ast::StatementList{
+           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
+                                        Expr(2u)),
+       },
+       ast::DecorationList{});
+  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  EXPECT_FALSE(r()->Resolve());
+  EXPECT_EQ(r()->error(),
+            "12:34 error v-000y: return statement type must match its function "
+            "return type, returned 'u32', expected 'myf32'");
+}
+
 }  // namespace
 }  // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index fa5a25f..b4915cf 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -406,8 +406,7 @@
     });
   }
   if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    current_function_->return_statements.push_back(r);
-    return Expression(r->value());
+    return Return(r);
   }
   if (auto* s = stmt->As<ast::SwitchStatement>()) {
     if (!Expression(s->condition())) {
@@ -1647,6 +1646,36 @@
   return info;
 }
 
+bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
+  type::Type* func_type = current_function_->declaration->return_type();
+
+  auto* ret_type = ret->has_value() ? TypeOf(ret->value())->UnwrapAll()
+                                    : builder_->ty.void_();
+
+  if (func_type->UnwrapAll() != ret_type) {
+    diagnostics_.add_error(
+        "v-000y",
+        "return statement type must match its function "
+        "return type, returned '" +
+            ret_type->FriendlyName(builder_->Symbols()) + "', expected '" +
+            func_type->FriendlyName(builder_->Symbols()) + "'",
+        ret->source());
+    return false;
+  }
+
+  return true;
+}
+
+bool Resolver::Return(ast::ReturnStatement* ret) {
+  current_function_->return_statements.push_back(ret);
+
+  auto result = Expression(ret->value());
+
+  // Validate after processing the return value expression so that its type is
+  // available for validation
+  return result && ValidateReturn(ret);
+}
+
 bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
                                             type::Type* ty,
                                             Source usage) {
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 97938c1..13a0453 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -219,6 +219,7 @@
   bool Statements(const ast::StatementList&);
   bool UnaryOp(ast::UnaryOpExpression*);
   bool VariableDeclStatement(const ast::VariableDeclStatement*);
+  bool Return(ast::ReturnStatement* ret);
 
   // AST and Type validation methods
   // Each return true on success, false on failure.
@@ -226,6 +227,7 @@
   bool ValidateParameter(const ast::Variable* param);
   bool ValidateFunction(const ast::Function* func);
   bool ValidateStructure(const type::Struct* st);
+  bool ValidateReturn(const ast::ReturnStatement* ret);
 
   /// @returns the semantic information for the array `arr`, building it if it
   /// hasn't been constructed already. If an error is raised, nullptr is
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 57b0820..53e35b0 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -228,7 +228,7 @@
   auto* cond = Expr(2);
 
   auto* ret = create<ast::ReturnStatement>(cond);
-  WrapInFunction(ret);
+  Func("test", {}, ty.i32(), {ret}, {});
 
   EXPECT_TRUE(r()->Resolve()) << r()->error();
 
diff --git a/src/transform/renamer_test.cc b/src/transform/renamer_test.cc
index 887cfaa..cfd3290 100644
--- a/src/transform/renamer_test.cc
+++ b/src/transform/renamer_test.cc
@@ -83,7 +83,7 @@
 TEST_F(RenamerTest, PreserveSwizzles) {
   auto* src = R"(
 [[stage(vertex)]]
-fn entry() -> void {
+fn entry() -> vec4<f32> {
   var v : vec4<f32>;
   var rgba : f32;
   var xyzw : f32;
@@ -93,7 +93,7 @@
 
   auto* expect = R"(
 [[stage(vertex)]]
-fn _tint_1() -> void {
+fn _tint_1() -> vec4<f32> {
   var _tint_2 : vec4<f32>;
   var _tint_3 : f32;
   var _tint_4 : f32;
@@ -120,7 +120,7 @@
 TEST_F(RenamerTest, PreserveIntrinsics) {
   auto* src = R"(
 [[stage(vertex)]]
-fn entry() -> void {
+fn entry() -> vec4<f32> {
   var blah : vec4<f32>;
   return abs(blah);
 }
@@ -128,7 +128,7 @@
 
   auto* expect = R"(
 [[stage(vertex)]]
-fn _tint_1() -> void {
+fn _tint_1() -> vec4<f32> {
   var _tint_2 : vec4<f32>;
   return abs(_tint_2);
 }
diff --git a/src/validator/validator_function_test.cc b/src/validator/validator_function_test.cc
index e5cc9bc..d74e103 100644
--- a/src/validator/validator_function_test.cc
+++ b/src/validator/validator_function_test.cc
@@ -56,123 +56,6 @@
   EXPECT_TRUE(v.Validate());
 }
 
-TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) {
-  // [[stage(vertex)]]
-  // fn func -> void { return; }
-
-  Func("func", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(),
-       },
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
-       });
-
-  ValidatorImpl& v = Build();
-
-  EXPECT_TRUE(v.Validate()) << v.error();
-}
-
-TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) {
-  // fn func -> void { return 2; }
-  Func("func", ast::VariableList{}, ty.void_(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
-                                        Expr(2)),
-       },
-       ast::DecorationList{});
-
-  ValidatorImpl& v = Build();
-
-  EXPECT_FALSE(v.Validate());
-  // TODO(sarahM0): replace 000y with a rule number
-  EXPECT_EQ(v.error(),
-            "12:34 v-000y: return statement type must match its function "
-            "return type, returned '__i32', expected '__void'");
-}
-
-TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_pass) {
-  // fn func -> f32 { return 2.0; }
-  Func("func", ast::VariableList{}, ty.f32(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
-                                        Expr(2.f)),
-       },
-       ast::DecorationList{});
-  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
-       });
-
-  ValidatorImpl& v = Build();
-
-  EXPECT_TRUE(v.Validate());
-}
-
-TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) {
-  // fn func -> f32 { return 2; }
-  Func("func", ast::VariableList{}, ty.f32(),
-       ast::StatementList{
-           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
-                                        Expr(2)),
-       },
-       ast::DecorationList{});
-
-  ValidatorImpl& v = Build();
-
-  EXPECT_FALSE(v.Validate());
-  // TODO(sarahM0): replace 000y with a rule number
-  EXPECT_EQ(v.error(),
-            "12:34 v-000y: return statement type must match its function "
-            "return type, returned '__i32', expected '__f32'");
-}
-
-TEST_F(ValidateFunctionTest,
-       FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) {
-  // type myf32 = f32;
-  // fn func -> myf32 { return 2.0; }
-  auto* myf32 = ty.alias("myf32", ty.f32());
-  Func("func", ast::VariableList{}, myf32,
-       ast::StatementList{
-           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
-                                        Expr(2.f)),
-       },
-       ast::DecorationList{});
-  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
-       });
-
-  ValidatorImpl& v = Build();
-
-  EXPECT_TRUE(v.Validate());
-}
-
-TEST_F(ValidateFunctionTest,
-       FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) {
-  // type myf32 = f32;
-  // fn func -> myf32 { return 2; }
-  auto* myf32 = ty.alias("myf32", ty.f32());
-  Func("func", ast::VariableList{}, myf32,
-       ast::StatementList{
-           create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
-                                        Expr(2u)),
-       },
-       ast::DecorationList{});
-  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
-       ast::DecorationList{
-           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
-       });
-
-  ValidatorImpl& v = Build();
-
-  EXPECT_FALSE(v.Validate());
-  EXPECT_EQ(
-      v.error(),
-      "12:34 v-000y: return statement type must match its function "
-      "return type, returned '__u32', expected '__alias_tint_symbol_1__f32'");
-}
-
 TEST_F(ValidateFunctionTest, PipelineStage_MustBeUnique_Fail) {
   // [[stage(fragment)]]
   // [[stage(vertex)]]
diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc
index 65e2a33..16187d9 100644
--- a/src/validator/validator_impl.cc
+++ b/src/validator/validator_impl.cc
@@ -182,28 +182,6 @@
   return true;
 }
 
-bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) {
-  // TODO(sarahM0): update this when this issue resolves:
-  // https://github.com/gpuweb/gpuweb/issues/996
-  type::Type* func_type = current_function_->return_type();
-
-  type::Void void_type;
-  auto* ret_type = ret->has_value()
-                       ? program_->Sem().Get(ret->value())->Type()->UnwrapAll()
-                       : &void_type;
-
-  if (func_type->UnwrapAll()->type_name() != ret_type->type_name()) {
-    add_error(ret->source(), "v-000y",
-              "return statement type must match its function return "
-              "type, returned '" +
-                  ret_type->type_name() + "', expected '" +
-                  func_type->type_name() + "'");
-    return false;
-  }
-
-  return true;
-}
-
 bool ValidatorImpl::ValidateStatements(const ast::BlockStatement* block) {
   if (!block) {
     return false;
@@ -262,9 +240,6 @@
   if (auto* a = stmt->As<ast::AssignmentStatement>()) {
     return ValidateAssign(a);
   }
-  if (auto* r = stmt->As<ast::ReturnStatement>()) {
-    return ValidateReturnStatement(r);
-  }
   if (auto* s = stmt->As<ast::SwitchStatement>()) {
     return ValidateSwitch(s);
   }
diff --git a/src/validator/validator_impl.h b/src/validator/validator_impl.h
index 2ec4c22..04edf95 100644
--- a/src/validator/validator_impl.h
+++ b/src/validator/validator_impl.h
@@ -99,10 +99,6 @@
   /// @returns true if no previous declaration with the `decl` 's name
   /// exist in the variable stack
   bool ValidateDeclStatement(const ast::VariableDeclStatement* decl);
-  /// Validates return statement
-  /// @param ret the return statement to check
-  /// @returns true if function return type matches the return statement type
-  bool ValidateReturnStatement(const ast::ReturnStatement* ret);
   /// Validates switch statements
   /// @param s the switch statement to check
   /// @returns true if the valdiation was successful
diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc
index de6fc29..3f53110 100644
--- a/src/writer/spirv/builder_call_test.cc
+++ b/src/writer/spirv/builder_call_test.cc
@@ -79,7 +79,7 @@
   func_params.push_back(Var("b", ty.f32(), ast::StorageClass::kFunction));
 
   auto* a_func =
-      Func("a_func", func_params, ty.void_(),
+      Func("a_func", func_params, ty.f32(),
            ast::StatementList{create<ast::ReturnStatement>(Add("a", "b"))},
            ast::DecorationList{});
 
@@ -96,27 +96,27 @@
   ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
 
   EXPECT_TRUE(b.GenerateStatement(expr)) << b.error();
-  EXPECT_EQ(DumpBuilder(b), R"(OpName %4 "a_func"
-OpName %5 "a"
-OpName %6 "b"
+  EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
+OpName %4 "a"
+OpName %5 "b"
 OpName %12 "main"
-%2 = OpTypeVoid
-%3 = OpTypeFloat 32
-%1 = OpTypeFunction %2 %3 %3
-%11 = OpTypeFunction %2
-%15 = OpConstant %3 1
-%4 = OpFunction %2 None %1
-%5 = OpFunctionParameter %3
-%6 = OpFunctionParameter %3
-%7 = OpLabel
-%8 = OpLoad %3 %5
-%9 = OpLoad %3 %6
-%10 = OpFAdd %3 %8 %9
-OpReturnValue %10
+%2 = OpTypeFloat 32
+%1 = OpTypeFunction %2 %2 %2
+%11 = OpTypeVoid
+%10 = OpTypeFunction %11
+%15 = OpConstant %2 1
+%3 = OpFunction %2 None %1
+%4 = OpFunctionParameter %2
+%5 = OpFunctionParameter %2
+%6 = OpLabel
+%7 = OpLoad %2 %4
+%8 = OpLoad %2 %5
+%9 = OpFAdd %2 %7 %8
+OpReturnValue %9
 OpFunctionEnd
-%12 = OpFunction %2 None %11
+%12 = OpFunction %11 None %10
 %13 = OpLabel
-%14 = OpFunctionCall %2 %4 %15 %15
+%14 = OpFunctionCall %2 %3 %15 %15
 OpReturn
 OpFunctionEnd
 )");
diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc
index 392a05e..e92680d 100644
--- a/src/writer/spirv/builder_function_test.cc
+++ b/src/writer/spirv/builder_function_test.cc
@@ -65,7 +65,7 @@
 TEST_F(BuilderTest, Function_Terminator_ReturnValue) {
   Global("a", ty.f32(), ast::StorageClass::kPrivate);
 
-  Func("a_func", {}, ty.void_(),
+  Func("a_func", {}, ty.f32(),
        ast::StatementList{create<ast::ReturnStatement>(Expr("a"))},
        ast::DecorationList{});
 
@@ -77,17 +77,16 @@
   ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
   ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
   EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "a"
-OpName %7 "a_func"
+OpName %6 "a_func"
 %3 = OpTypeFloat 32
 %2 = OpTypePointer Private %3
 %4 = OpConstantNull %3
 %1 = OpVariable %2 Private %4
-%6 = OpTypeVoid
-%5 = OpTypeFunction %6
-%7 = OpFunction %6 None %5
-%8 = OpLabel
-%9 = OpLoad %3 %1
-OpReturnValue %9
+%5 = OpTypeFunction %3
+%6 = OpFunction %3 None %5
+%7 = OpLabel
+%8 = OpLoad %3 %1
+OpReturnValue %8
 OpFunctionEnd
 )");
 }
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index 6918e86..90215a5 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -511,14 +511,10 @@
   // if (true) {
   //   return false;
   // }
-  auto* if_body = create<ast::BlockStatement>(ast::StatementList{
-      create<ast::ReturnStatement>(Expr(false)),
-  });
-
-  auto* expr =
-      create<ast::IfStatement>(Expr(true), if_body, ast::ElseStatementList{});
-  WrapInFunction(expr);
-
+  // return true;
+  auto* if_body = Block(Return(Expr(false)));
+  auto* expr = If(Expr(true), if_body);
+  Func("test", {}, ty.bool_(), {expr, Return(Expr(true))}, {});
   spirv::Builder& b = Build();
 
   b.push_function(Function{});
diff --git a/src/writer/spirv/builder_return_test.cc b/src/writer/spirv/builder_return_test.cc
index 17265b8..3169a2f 100644
--- a/src/writer/spirv/builder_return_test.cc
+++ b/src/writer/spirv/builder_return_test.cc
@@ -39,7 +39,7 @@
   auto* val = vec3<f32>(1.f, 1.f, 3.f);
 
   auto* ret = create<ast::ReturnStatement>(val);
-  WrapInFunction(ret);
+  Func("test", {}, ty.vec3<f32>(), {ret}, {});
 
   spirv::Builder& b = Build();
 
@@ -62,7 +62,7 @@
   auto* var = Global("param", ty.f32(), ast::StorageClass::kFunction);
 
   auto* ret = create<ast::ReturnStatement>(Expr("param"));
-  WrapInFunction(ret);
+  Func("test", {}, ty.f32(), {ret}, {});
 
   spirv::Builder& b = Build();