[spirv-writer] Handle non-struct entry point return values

Generate a global variable for the return value and replace return
statements with assignments to this variable.

Add a list of return statements to semantic::Function.

Bug: tint:509
Change-Id: I6bc08fcac7858b48f0eff62199d5011665284220
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44804
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h
index beb7e15..2853ad4 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -381,12 +381,12 @@
   using CloneableList = std::vector<Cloneable*>;
 
   /// A map of object in #src to their cloned equivalent in #dst
-  std::unordered_map<Cloneable*, Cloneable*> cloned_;
+  std::unordered_map<const Cloneable*, Cloneable*> cloned_;
 
   /// A map of object in #src to the list of cloned objects in #dst.
   /// Clone(const std::vector<T*>& v) will use this to insert the map-value list
   /// into the target vector/ before cloning and inserting the map-key.
-  std::unordered_map<Cloneable*, CloneableList> insert_before_;
+  std::unordered_map<const Cloneable*, CloneableList> insert_before_;
 
   /// Cloneable transform functions registered with ReplaceAll()
   std::vector<CloneableTransform> transforms_;
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 9198c39..c5711bf 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -321,6 +321,7 @@
     });
   }
   if (auto* r = stmt->As<ast::ReturnStatement>()) {
+    current_function_->return_statements.push_back(r);
     return Expression(r->value());
   }
   if (auto* s = stmt->As<ast::SwitchStatement>()) {
@@ -1215,7 +1216,7 @@
 
     auto* sem_func = builder_->create<semantic::Function>(
         info->declaration, remap_vars(info->referenced_module_vars),
-        remap_vars(info->local_referenced_module_vars),
+        remap_vars(info->local_referenced_module_vars), info->return_statements,
         ancestor_entry_points[func->symbol()]);
     func_info_to_sem_func.emplace(info, sem_func);
     sem.Add(func, sem_func);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index d8a3d81..8fd6900 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -38,6 +38,7 @@
 class Function;
 class IdentifierExpression;
 class MemberAccessorExpression;
+class ReturnStatement;
 class UnaryOpExpression;
 class Variable;
 }  // namespace ast
@@ -92,6 +93,7 @@
     ast::Function* const declaration;
     UniqueVector<VariableInfo*> referenced_module_vars;
     UniqueVector<VariableInfo*> local_referenced_module_vars;
+    std::vector<const ast::ReturnStatement*> return_statements;
 
     // List of transitive calls this function makes
     UniqueVector<FunctionInfo*> transitive_calls;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index b85dcb8..bcaa8bc 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -858,6 +858,30 @@
   EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
 }
 
+TEST_F(ResolverTest, Function_ReturnStatements) {
+  auto* var = Var("foo", ty.f32(), ast::StorageClass::kFunction);
+
+  auto* ret_1 = create<ast::ReturnStatement>(Expr(1.f));
+  auto* ret_foo = create<ast::ReturnStatement>(Expr("foo"));
+
+  auto* func = Func("my_func", ast::VariableList{}, ty.f32(),
+                    ast::StatementList{
+                        create<ast::VariableDeclStatement>(var),
+                        If(Expr(true), Block(ret_1)),
+                        ret_foo,
+                    },
+                    ast::DecorationList{});
+
+  EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+  auto* func_sem = Sem().Get(func);
+  ASSERT_NE(func_sem, nullptr);
+
+  EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
+  EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
+  EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
+}
+
 TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
   auto* strct = create<ast::Struct>(
       ast::StructMemberList{Member("first_member", ty.i32()),
diff --git a/src/semantic/function.h b/src/semantic/function.h
index 7cff766..fbb5934 100644
--- a/src/semantic/function.h
+++ b/src/semantic/function.h
@@ -29,6 +29,7 @@
 class Function;
 class GroupDecoration;
 class LocationDecoration;
+class ReturnStatement;
 }  // namespace ast
 
 namespace semantic {
@@ -53,11 +54,13 @@
   /// @param declaration the ast::Function
   /// @param referenced_module_vars the referenced module variables
   /// @param local_referenced_module_vars the locally referenced module
+  /// @param return_statements the function return statements
   /// variables
   /// @param ancestor_entry_points the ancestor entry points
   Function(ast::Function* declaration,
            std::vector<const Variable*> referenced_module_vars,
            std::vector<const Variable*> local_referenced_module_vars,
+           std::vector<const ast::ReturnStatement*> return_statements,
            std::vector<Symbol> ancestor_entry_points);
 
   /// Destructor
@@ -76,6 +79,10 @@
   const std::vector<const Variable*>& LocalReferencedModuleVariables() const {
     return local_referenced_module_vars_;
   }
+  /// @returns the return statements
+  const std::vector<const ast::ReturnStatement*> ReturnStatements() const {
+    return return_statements_;
+  }
   /// @returns the ancestor entry points
   const std::vector<Symbol>& AncestorEntryPoints() const {
     return ancestor_entry_points_;
@@ -148,6 +155,7 @@
   ast::Function* const declaration_;
   std::vector<const Variable*> const referenced_module_vars_;
   std::vector<const Variable*> const local_referenced_module_vars_;
+  std::vector<const ast::ReturnStatement*> const return_statements_;
   std::vector<Symbol> const ancestor_entry_points_;
 };
 
diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc
index 4ae8885..15e0b65 100644
--- a/src/semantic/sem_function.cc
+++ b/src/semantic/sem_function.cc
@@ -57,11 +57,13 @@
 Function::Function(ast::Function* declaration,
                    std::vector<const Variable*> referenced_module_vars,
                    std::vector<const Variable*> local_referenced_module_vars,
+                   std::vector<const ast::ReturnStatement*> return_statements,
                    std::vector<Symbol> ancestor_entry_points)
     : Base(declaration->return_type(), GetParameters(declaration)),
       declaration_(declaration),
       referenced_module_vars_(std::move(referenced_module_vars)),
       local_referenced_module_vars_(std::move(local_referenced_module_vars)),
+      return_statements_(std::move(return_statements)),
       ancestor_entry_points_(std::move(ancestor_entry_points)) {}
 
 Function::~Function() = default;
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 97ce585..d404d4a 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -17,7 +17,9 @@
 #include <string>
 #include <utility>
 
+#include "src/ast/return_statement.h"
 #include "src/program_builder.h"
+#include "src/semantic/function.h"
 #include "src/semantic/variable.h"
 
 namespace tint {
@@ -102,6 +104,8 @@
       continue;
     }
 
+    auto* sem_func = ctx.src->Sem().Get(func);
+
     for (auto* param : func->params()) {
       // TODO(jrprice): Handle structures by moving the declaration and
       // construction to the function body.
@@ -126,21 +130,37 @@
       }
     }
 
-    // TODO(jrprice): Hoist the return type out to a global variable, and
-    // replace return statements with variable assignments.
     if (!func->return_type()->Is<type::Void>()) {
-      TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
-          << "entry point return values are not yet supported";
-      continue;
+      // TODO(jrprice): Handle structures by creating a variable for each member
+      // and replacing return statements with extracts+stores.
+      if (func->return_type()->UnwrapAll()->Is<type::Struct>()) {
+        TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
+            << "structures as entry point return values are not yet supported";
+        continue;
+      }
+
+      // Create a new symbol for the global variable.
+      auto var_symbol = ctx.dst->Symbols().New();
+      // Create the global variable.
+      auto* var = ctx.dst->Var(var_symbol, ctx.Clone(func->return_type()),
+                               ast::StorageClass::kOutput, nullptr,
+                               ctx.Clone(func->return_type_decorations()));
+      ctx.InsertBefore(func, var);
+
+      // Replace all return statements with stores to the global variable.
+      for (auto* ret : sem_func->ReturnStatements()) {
+        ctx.InsertBefore(
+            ret, ctx.dst->create<ast::AssignmentStatement>(
+                     ctx.dst->Expr(var_symbol), ctx.Clone(ret->value())));
+        ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
+      }
     }
 
-    // Rewrite the function header to remove the parameters.
-    // TODO(jrprice): Change return type to void when return values are handled.
+    // Rewrite the function header to remove the parameters and return value.
     auto* new_func = ctx.dst->create<ast::Function>(
         func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
-        ctx.Clone(func->return_type()), ctx.Clone(func->body()),
-        ctx.Clone(func->decorations()),
-        ctx.Clone(func->return_type_decorations()));
+        ctx.dst->ty.void_(), ctx.Clone(func->body()),
+        ctx.Clone(func->decorations()), ast::DecorationList{});
     ctx.Replace(func, new_func);
   }
 }
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index 2f03ad4..2a22b67 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -86,6 +86,97 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnBuiltin) {
+  auto* src = R"(
+[[stage(vertex)]]
+fn vert_main() -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(1.0, 2.0, 3.0, 0.0);
+}
+)";
+
+  auto* expect = R"(
+[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+  tint_symbol_1 = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation) {
+  auto* src = R"(
+[[stage(fragment)]]
+fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
+  if (loc_in > 10u) {
+    return 0.5;
+  }
+  return 1.0;
+}
+)";
+
+  auto* expect = R"(
+[[location(0)]] var<in> tint_symbol_1 : u32;
+
+[[location(0)]] var<out> tint_symbol_2 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  if ((tint_symbol_1 > 10u)) {
+    tint_symbol_2 = 0.5;
+    return;
+  }
+  tint_symbol_2 = 1.0;
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation_TypeAlias) {
+  auto* src = R"(
+type myf32 = f32;
+
+[[stage(fragment)]]
+fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
+  if (loc_in > 10u) {
+    return 0.5;
+  }
+  return 1.0;
+}
+)";
+
+  auto* expect = R"(
+type myf32 = f32;
+
+[[location(0)]] var<in> tint_symbol_1 : u32;
+
+[[location(0)]] var<out> tint_symbol_2 : myf32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+  if ((tint_symbol_1 > 10u)) {
+    tint_symbol_2 = 0.5;
+    return;
+  }
+  tint_symbol_2 = 1.0;
+  return;
+}
+)";
+
+  auto got = Run<Spirv>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
   auto* src = R"(
 [[builtin(sample_index)]] var<in> sample_index : u32;
@@ -164,27 +255,26 @@
 
 // Test that different transforms within the sanitizer interact correctly.
 TEST_F(SpirvTest, MultipleTransforms) {
-  // TODO(jrprice): Make `mask_out` a return value when supported.
   auto* src = R"(
-[[builtin(sample_mask_out)]] var<out> mask_out : u32;
-
 [[stage(fragment)]]
 fn main([[builtin(sample_index)]] sample_index : u32,
-        [[builtin(sample_mask_in)]] mask_in : u32) -> void {
-  mask_out = mask_in;
+        [[builtin(sample_mask_in)]] mask_in : u32)
+        -> [[builtin(sample_mask_out)]] u32 {
+  return mask_in;
 }
 )";
 
   auto* expect = R"(
-[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
-
 [[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
 
 [[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
 
+[[builtin(sample_mask_out)]] var<out> tint_symbol_3 : array<u32, 1>;
+
 [[stage(fragment)]]
 fn main() -> void {
-  mask_out[0] = tint_symbol_2[0];
+  tint_symbol_3[0] = tint_symbol_2[0];
+  return;
 }
 )";
 
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index a0c52c3..a45e506 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -18,6 +18,7 @@
 #include "src/ast/builtin.h"
 #include "src/ast/builtin_decoration.h"
 #include "src/ast/location_decoration.h"
+#include "src/ast/return_statement.h"
 #include "src/ast/stage_decoration.h"
 #include "src/ast/storage_class.h"
 #include "src/ast/variable.h"
@@ -96,6 +97,74 @@
 )");
 }
 
+TEST_F(BuilderTest, EntryPoint_ReturnValue) {
+  // [[stage(fragment)]]
+  // fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
+  //   if (loc_in > 10) {
+  //     return 0.5;
+  //   }
+  //   return 1.0;
+  // }
+  auto* f32 = ty.f32();
+  auto* u32 = ty.u32();
+  auto* loc_in = Var("loc_in", u32, ast::StorageClass::kFunction, nullptr,
+                     {create<ast::LocationDecoration>(0)});
+  auto* cond = create<ast::BinaryExpression>(ast::BinaryOp::kGreaterThan,
+                                             Expr("loc_in"), Expr(10u));
+  Func("frag_main", ast::VariableList{loc_in}, f32,
+       ast::StatementList{
+           If(cond, Block(create<ast::ReturnStatement>(Expr(0.5f)))),
+           create<ast::ReturnStatement>(Expr(1.0f)),
+       },
+       ast::DecorationList{
+           create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+       },
+       ast::DecorationList{create<ast::LocationDecoration>(0)});
+
+  spirv::Builder& b = SanitizeAndBuild();
+
+  ASSERT_TRUE(b.Build());
+
+  // Test that the return value gets hoisted out to a global variable with the
+  // Output storage class, and the return statements are replaced with stores.
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %10 "frag_main" %1 %4
+OpExecutionMode %10 OriginUpperLeft
+OpName %1 "tint_symbol_1"
+OpName %4 "tint_symbol_2"
+OpName %10 "frag_main"
+OpDecorate %1 Location 0
+OpDecorate %4 Location 0
+%3 = OpTypeInt 32 0
+%2 = OpTypePointer Input %3
+%1 = OpVariable %2 Input
+%6 = OpTypeFloat 32
+%5 = OpTypePointer Output %6
+%7 = OpConstantNull %6
+%4 = OpVariable %5 Output %7
+%9 = OpTypeVoid
+%8 = OpTypeFunction %9
+%13 = OpConstant %3 10
+%15 = OpTypeBool
+%18 = OpConstant %6 0.5
+%19 = OpConstant %6 1
+%10 = OpFunction %9 None %8
+%11 = OpLabel
+%12 = OpLoad %3 %1
+%14 = OpUGreaterThan %15 %12 %13
+OpSelectionMerge %16 None
+OpBranchConditional %14 %17 %16
+%17 = OpLabel
+OpStore %4 %18
+OpReturn
+%16 = OpLabel
+OpStore %4 %19
+OpReturn
+OpFunctionEnd
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer