sem::Function: Add ReturnType()

This is the resolved, semantic, return type of the function.

Bug: tint:724
Change-Id: I4ef9f7874414a3ea48131d0102da776f6d82a729
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49526
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 471552b..f767c86 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -588,7 +588,7 @@
     }
   }
 
-  if (!func->return_type()->Is<sem::Void>()) {
+  if (!info->return_type->Is<sem::Void>()) {
     if (func->body()) {
       if (!func->get_last_statement() ||
           !func->get_last_statement()->Is<ast::ReturnStatement>()) {
@@ -809,9 +809,9 @@
   builtins.clear();
   locations.clear();
 
-  if (!func->return_type()->Is<sem::Void>()) {
+  if (!info->return_type->Is<sem::Void>()) {
     if (!validate_entry_point_decorations(func->return_type_decorations(),
-                                          func->return_type(), func->source(),
+                                          info->return_type, func->source(),
                                           ParamOrRetType::kReturnType)) {
       return false;
     }
@@ -844,9 +844,9 @@
 }
 
 bool Resolver::Function(ast::Function* func) {
-  auto* func_info = function_infos_.Create<FunctionInfo>(func);
+  auto* info = function_infos_.Create<FunctionInfo>(func);
 
-  ScopedAssignment<FunctionInfo*> sa(current_function_, func_info);
+  ScopedAssignment<FunctionInfo*> sa(current_function_, info);
 
   variable_stack_.push_scope();
   for (auto* param : func->params()) {
@@ -862,11 +862,10 @@
     }
 
     variable_stack_.set(param->symbol(), param_info);
-    func_info->parameters.emplace_back(param_info);
+    info->parameters.emplace_back(param_info);
 
     if (!ApplyStorageClassUsageToType(param->declared_storage_class(),
-                                      param->declared_type(),
-                                      param->source())) {
+                                      param_info->type, param->source())) {
       diagnostics_.add_note("while instantiating parameter " +
                                 builder_->Symbols().NameFor(param->symbol()),
                             param->source());
@@ -874,21 +873,21 @@
     }
 
     if (auto* str = param_info->type->As<sem::StructType>()) {
-      auto* info = Structure(str);
-      if (!info) {
+      auto* str_info = Structure(str);
+      if (!str_info) {
         return false;
       }
       switch (func->pipeline_stage()) {
         case ast::PipelineStage::kVertex:
-          info->pipeline_stage_uses.emplace(
+          str_info->pipeline_stage_uses.emplace(
               sem::PipelineStageUsage::kVertexInput);
           break;
         case ast::PipelineStage::kFragment:
-          info->pipeline_stage_uses.emplace(
+          str_info->pipeline_stage_uses.emplace(
               sem::PipelineStageUsage::kFragmentInput);
           break;
         case ast::PipelineStage::kCompute:
-          info->pipeline_stage_uses.emplace(
+          str_info->pipeline_stage_uses.emplace(
               sem::PipelineStageUsage::kComputeInput);
           break;
         case ast::PipelineStage::kNone:
@@ -897,7 +896,22 @@
     }
   }
 
-  if (auto* str = Canonical(func->return_type())->As<sem::StructType>()) {
+  if (func->return_type().ast || func->return_type().sem) {
+    info->return_type = func->return_type();
+    if (!info->return_type) {
+      info->return_type = Type(func->return_type().ast);
+    }
+    if (!info->return_type) {
+      return false;
+    }
+  } else {
+    info->return_type = builder_->create<sem::Void>();
+  }
+
+  info->return_type_name = info->return_type->FriendlyName(builder_->Symbols());
+  info->return_type = Canonical(info->return_type);
+
+  if (auto* str = info->return_type->As<sem::StructType>()) {
     if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
                                       func->source())) {
       diagnostics_.add_note("while instantiating return type for " +
@@ -906,21 +920,21 @@
       return false;
     }
 
-    auto* info = Structure(str);
-    if (!info) {
+    auto* str_info = Structure(str);
+    if (!str_info) {
       return false;
     }
     switch (func->pipeline_stage()) {
       case ast::PipelineStage::kVertex:
-        info->pipeline_stage_uses.emplace(
+        str_info->pipeline_stage_uses.emplace(
             sem::PipelineStageUsage::kVertexOutput);
         break;
       case ast::PipelineStage::kFragment:
-        info->pipeline_stage_uses.emplace(
+        str_info->pipeline_stage_uses.emplace(
             sem::PipelineStageUsage::kFragmentOutput);
         break;
       case ast::PipelineStage::kCompute:
-        info->pipeline_stage_uses.emplace(
+        str_info->pipeline_stage_uses.emplace(
             sem::PipelineStageUsage::kComputeOutput);
         break;
       case ast::PipelineStage::kNone:
@@ -943,15 +957,15 @@
     Mark(deco);
   }
 
-  if (!ValidateFunction(func, func_info)) {
+  if (!ValidateFunction(func, info)) {
     return false;
   }
 
   // Register the function information _after_ processing the statements. This
   // allows us to catch a function calling itself when determining the call
   // information as this function doesn't exist until it's finished.
-  symbol_to_function_[func->symbol()] = func_info;
-  function_to_info_.emplace(func, func_info);
+  symbol_to_function_[func->symbol()] = info;
+  function_to_info_.emplace(func, info);
 
   return true;
 }
@@ -1274,7 +1288,7 @@
     auto* function = iter->second;
     function_calls_.emplace(call,
                             FunctionCallInfo{function, current_statement_});
-    SetType(call, function->declaration->return_type());
+    SetType(call, function->return_type, function->return_type_name);
   }
 
   return true;
@@ -2093,8 +2107,8 @@
     auto* info = it.second;
 
     auto* sem_func = builder_->create<sem::Function>(
-        info->declaration, remap_vars(info->parameters),
-        remap_vars(info->referenced_module_vars),
+        info->declaration, const_cast<sem::Type*>(info->return_type),
+        remap_vars(info->parameters), remap_vars(info->referenced_module_vars),
         remap_vars(info->local_referenced_module_vars), info->return_statements,
         ancestor_entry_points[func->symbol()]);
     func_info_to_sem_func.emplace(info, sem_func);
@@ -2479,19 +2493,19 @@
 }
 
 bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
-  sem::Type* func_type = current_function_->declaration->return_type();
+  auto* func_type = current_function_->return_type;
 
   auto* ret_type = ret->has_value() ? TypeOf(ret->value())->UnwrapAll()
                                     : builder_->ty.void_();
 
   if (func_type->UnwrapAll() != ret_type) {
-    diagnostics_.add_error(
-        "v-000y",
-        "return statement type must match its function "
-        "return type, returned '" +
-            ret_type->FriendlyName(builder_->Symbols()) + "', expected '" +
-            func_type->FriendlyName(builder_->Symbols()) + "'",
-        ret->source());
+    diagnostics_.add_error("v-000y",
+                           "return statement type must match its function "
+                           "return type, returned '" +
+                               ret_type->FriendlyName(builder_->Symbols()) +
+                               "', expected '" +
+                               current_function_->return_type_name + "'",
+                           ret->source());
     return false;
   }
 
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index c579c62..ad2635c 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -119,6 +119,8 @@
     UniqueVector<VariableInfo*> referenced_module_vars;
     UniqueVector<VariableInfo*> local_referenced_module_vars;
     std::vector<const ast::ReturnStatement*> return_statements;
+    sem::Type const* return_type = nullptr;
+    std::string return_type_name;
 
     // List of transitive calls this function makes
     UniqueVector<FunctionInfo*> transitive_calls;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index 8ae5475..a3ddc26 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -758,6 +758,7 @@
   EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
   EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
   EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
+  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
 }
 
 TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
@@ -785,6 +786,7 @@
   auto* func_sem = Sem().Get(func);
   ASSERT_NE(func_sem, nullptr);
   EXPECT_EQ(func_sem->Parameters().size(), 0u);
+  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
 
   const auto& vars = func_sem->ReferencedModuleVariables();
   ASSERT_EQ(vars.size(), 5u);
@@ -851,6 +853,7 @@
   ASSERT_NE(func_sem, nullptr);
 
   EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
+  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
 }
 
 TEST_F(ResolverTest, Function_ReturnStatements) {
@@ -875,6 +878,7 @@
   EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
   EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
   EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
+  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::F32>());
 }
 
 TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
diff --git a/src/sem/function.cc b/src/sem/function.cc
index cf28d02..528f348 100644
--- a/src/sem/function.cc
+++ b/src/sem/function.cc
@@ -41,12 +41,13 @@
 }  // namespace
 
 Function::Function(ast::Function* declaration,
+                   Type* return_type,
                    std::vector<const Variable*> parameters,
                    std::vector<const Variable*> referenced_module_vars,
                    std::vector<const Variable*> local_referenced_module_vars,
                    std::vector<const ast::ReturnStatement*> return_statements,
                    std::vector<Symbol> ancestor_entry_points)
-    : Base(declaration->return_type(), GetParameters(declaration)),
+    : Base(return_type, GetParameters(declaration)),
       declaration_(declaration),
       parameters_(std::move(parameters)),
       referenced_module_vars_(std::move(referenced_module_vars)),
@@ -138,8 +139,7 @@
   VariableBindings ret;
 
   for (auto* var : ReferencedModuleVariables()) {
-    auto* unwrapped_type =
-        var->Declaration()->declared_type()->UnwrapIfNeeded();
+    auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
     auto* storage_texture = unwrapped_type->As<sem::StorageTexture>();
     if (storage_texture == nullptr) {
       continue;
@@ -156,8 +156,7 @@
   VariableBindings ret;
 
   for (auto* var : ReferencedModuleVariables()) {
-    auto* unwrapped_type =
-        var->Declaration()->declared_type()->UnwrapIfNeeded();
+    auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
     auto* storage_texture = unwrapped_type->As<sem::DepthTexture>();
     if (storage_texture == nullptr) {
       continue;
@@ -184,8 +183,7 @@
   VariableBindings ret;
 
   for (auto* var : ReferencedModuleVariables()) {
-    auto* unwrapped_type =
-        var->Declaration()->declared_type()->UnwrapIfNeeded();
+    auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
     auto* sampler = unwrapped_type->As<sem::Sampler>();
     if (sampler == nullptr || sampler->kind() != kind) {
       continue;
@@ -203,8 +201,7 @@
   VariableBindings ret;
 
   for (auto* var : ReferencedModuleVariables()) {
-    auto* unwrapped_type =
-        var->Declaration()->declared_type()->UnwrapIfNeeded();
+    auto* unwrapped_type = var->Type()->UnwrapIfNeeded();
     auto* texture = unwrapped_type->As<sem::Texture>();
     if (texture == nullptr) {
       continue;
diff --git a/src/sem/function.h b/src/sem/function.h
index 4a43d02..3b107ee 100644
--- a/src/sem/function.h
+++ b/src/sem/function.h
@@ -46,6 +46,7 @@
 
   /// Constructor
   /// @param declaration the ast::Function
+  /// @param return_type the return type of the function
   /// @param parameters the parameters to the function
   /// @param referenced_module_vars the referenced module variables
   /// @param local_referenced_module_vars the locally referenced module
@@ -53,6 +54,7 @@
   /// variables
   /// @param ancestor_entry_points the ancestor entry points
   Function(ast::Function* declaration,
+           Type* return_type,
            std::vector<const Variable*> parameters,
            std::vector<const Variable*> referenced_module_vars,
            std::vector<const Variable*> local_referenced_module_vars,