transform/EmitVertexPointSize: Handle entry point parameters

Generate a new struct that contains members of the original return
type with the point size appended to it, and replace return statements
as necessary.

The SPIR-V sanitizer then special-cases this builtin when handling
entry point IO so that the value stored to it is always a literal.

Fixed: tint:732
Change-Id: Id718632a5e671f3e7c82a304f5bc1fc223a6c8ee
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49440
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc
index 9854149..13f76a0 100644
--- a/src/transform/emit_vertex_point_size.cc
+++ b/src/transform/emit_vertex_point_size.cc
@@ -14,10 +14,13 @@
 
 #include "src/transform/emit_vertex_point_size.h"
 
+#include <unordered_map>
 #include <utility>
 
-#include "src/ast/assignment_statement.h"
 #include "src/program_builder.h"
+#include "src/sem/function.h"
+#include "src/sem/statement.h"
+#include "src/utils/get_or_create.h"
 
 namespace tint {
 namespace transform {
@@ -26,34 +29,85 @@
 EmitVertexPointSize::~EmitVertexPointSize() = default;
 
 Output EmitVertexPointSize::Run(const Program* in, const DataMap&) {
-  if (!in->AST().Functions().HasStage(ast::PipelineStage::kVertex)) {
-    // If the module doesn't have any vertex stages, then there's nothing to do.
-    return Output(Program(in->Clone()));
-  }
-
   ProgramBuilder out;
-
   CloneContext ctx(&out, in);
 
-  Symbol pointsize = out.Symbols().New("tint_pointsize");
-
-  // Declare the pointsize builtin output variable.
-  out.Global(pointsize, out.ty.f32(), ast::StorageClass::kOutput, nullptr,
-             ast::DecorationList{
-                 out.Builtin(ast::Builtin::kPointSize),
-             });
-
-  // Add the pointsize assignment statement to the front of all vertex stages.
-  ctx.ReplaceAll([&](ast::Function* func) -> ast::Function* {
+  std::unordered_map<sem::Type*, sem::StructType*> struct_map;
+  for (auto* func : in->AST().Functions()) {
     if (func->pipeline_stage() != ast::PipelineStage::kVertex) {
-      return nullptr;  // Just clone func
+      continue;
     }
 
-    return CloneWithStatementsAtStart(&ctx, func,
-                                      {
-                                          out.Assign(pointsize, 1.0f),
-                                      });
-  });
+    auto* sem_func = in->Sem().Get(func);
+
+    // Create a struct for the return type that includes a point size member.
+    auto* new_struct =
+        utils::GetOrCreate(struct_map, sem_func->ReturnType(), [&]() {
+          // Gather struct members.
+          ast::StructMemberList new_struct_members;
+          if (auto* struct_ty = sem_func->ReturnType()->As<sem::StructType>()) {
+            for (auto* member : struct_ty->impl()->members()) {
+              new_struct_members.push_back(ctx.Clone(member));
+            }
+          } else {
+            auto* ret_type = ctx.Clone(sem_func->ReturnType());
+            auto ret_type_decos = ctx.Clone(func->return_type_decorations());
+            new_struct_members.push_back(
+                out.Member("position", ret_type, std::move(ret_type_decos)));
+          }
+
+          // Append a new member for the point size.
+          new_struct_members.push_back(
+              out.Member(out.Symbols().New("tint_pointsize"), out.ty.f32(),
+                         {out.Builtin(ast::Builtin::kPointSize)}));
+
+          // Create the new output struct.
+          return out.Structure(out.Sym(), new_struct_members);
+        });
+
+    // Replace return values using new output struct type constructors.
+    for (auto* ret : sem_func->ReturnStatements()) {
+      auto* ret_sem = in->Sem().Get(ret);
+
+      ast::ExpressionList new_ret_values;
+      if (auto* struct_ty = sem_func->ReturnType()->As<sem::StructType>()) {
+        std::function<ast::Expression*()> ret_value = [&]() {
+          return ctx.Clone(ret->value());
+        };
+
+        if (!ret->value()->Is<ast::IdentifierExpression>()) {
+          // Capture the original return value in a local temporary.
+          auto* new_struct_ty = ctx.Clone(struct_ty);
+          auto* temp = out.Const(out.Sym(), new_struct_ty, ret_value());
+          ctx.InsertBefore(ret_sem->Block()->statements(), ret, out.Decl(temp));
+          ret_value = [&, temp]() { return out.Expr(temp); };
+        }
+
+        for (auto* member : struct_ty->impl()->members()) {
+          auto member_sym = ctx.Clone(member->symbol());
+          new_ret_values.push_back(out.MemberAccessor(ret_value(), member_sym));
+        }
+      } else {
+        new_ret_values.push_back(ctx.Clone(ret->value()));
+      }
+
+      // Append the point size and replace the return statement.
+      new_ret_values.push_back(out.Expr(1.f));
+      ctx.Replace(ret, out.Return(ret->source(),
+                                  out.Construct(new_struct, new_ret_values)));
+    }
+
+    // Rewrite the function header with the new return type.
+    auto func_sym = ctx.Clone(func->symbol());
+    auto params = ctx.Clone(func->params());
+    auto* body = ctx.Clone(func->body());
+    auto decos = ctx.Clone(func->decorations());
+    auto* new_func = out.create<ast::Function>(
+        func->source(), func_sym, std::move(params), new_struct, body,
+        std::move(decos), ast::DecorationList{});
+    ctx.Replace(func, new_func);
+  }
+
   ctx.Clone();
 
   return Output(Program(std::move(out)));
diff --git a/src/transform/emit_vertex_point_size_test.cc b/src/transform/emit_vertex_point_size_test.cc
index f420115..0b70aee 100644
--- a/src/transform/emit_vertex_point_size_test.cc
+++ b/src/transform/emit_vertex_point_size_test.cc
@@ -29,7 +29,6 @@
 
 [[stage(vertex)]]
 fn entry() -> [[builtin(position)]] vec4<f32> {
-  var builtin_assignments_should_happen_before_this : f32;
   return vec4<f32>();
 }
 
@@ -38,16 +37,19 @@
 )";
 
   auto* expect = R"(
-[[builtin(pointsize)]] var<out> tint_pointsize : f32;
+struct tint_symbol {
+  [[builtin(position)]]
+  position : vec4<f32>;
+  [[builtin(pointsize)]]
+  tint_pointsize : f32;
+};
 
 fn non_entry_a() {
 }
 
 [[stage(vertex)]]
-fn entry() -> [[builtin(position)]] vec4<f32> {
-  tint_pointsize = 1.0;
-  var builtin_assignments_should_happen_before_this : f32;
-  return vec4<f32>();
+fn entry() -> tint_symbol {
+  return tint_symbol(vec4<f32>(), 1.0);
 }
 
 fn non_entry_b() {
@@ -59,6 +61,255 @@
   EXPECT_EQ(expect, str(got));
 }
 
+TEST_F(EmitVertexPointSizeTest, VertexStageBasic_Struct) {
+  auto* src = R"(
+struct VertexOut {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+};
+
+fn non_entry_a() {
+}
+
+[[stage(vertex)]]
+fn entry() -> VertexOut {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.5;
+  return output;
+}
+
+fn non_entry_b() {
+}
+)";
+
+  auto* expect = R"(
+struct tint_symbol {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+  [[builtin(pointsize)]]
+  tint_pointsize : f32;
+};
+
+struct VertexOut {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+};
+
+fn non_entry_a() {
+}
+
+[[stage(vertex)]]
+fn entry() -> tint_symbol {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.5;
+  return tint_symbol(output.pos, output.col, 1.0);
+}
+
+fn non_entry_b() {
+}
+)";
+
+  auto got = Run<EmitVertexPointSize>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+// Make sure we capture the function return value in a temporary instead of
+// re-evaluating it multiple times.
+TEST_F(EmitVertexPointSizeTest, VertexStage_ReturnStructFromFunctionCall) {
+  auto* src = R"(
+struct VertexOut {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+};
+
+fn foo() -> VertexOut {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.5;
+  return output;
+}
+
+[[stage(vertex)]]
+fn entry() -> VertexOut {
+  return foo();
+}
+)";
+
+  auto* expect = R"(
+struct tint_symbol {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+  [[builtin(pointsize)]]
+  tint_pointsize : f32;
+};
+
+struct VertexOut {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+};
+
+fn foo() -> VertexOut {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.5;
+  return output;
+}
+
+[[stage(vertex)]]
+fn entry() -> tint_symbol {
+  let tint_symbol_1 : VertexOut = foo();
+  return tint_symbol(tint_symbol_1.pos, tint_symbol_1.col, 1.0);
+}
+)";
+
+  auto got = Run<EmitVertexPointSize>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(EmitVertexPointSizeTest, VertexStage_MultipleReturnStatements) {
+  auto* src = R"(
+[[stage(vertex)]]
+fn entry([[location(0)]] toggle : u32) -> [[builtin(position)]] vec4<f32> {
+  if (toggle == 1u) {
+    return vec4<f32>(0.5, 0.5, 0.5, 0.5);
+  }
+  return vec4<f32>(1.0, 1.0, 1.0, 1.0);
+}
+)";
+
+  auto* expect = R"(
+struct tint_symbol {
+  [[builtin(position)]]
+  position : vec4<f32>;
+  [[builtin(pointsize)]]
+  tint_pointsize : f32;
+};
+
+[[stage(vertex)]]
+fn entry([[location(0)]] toggle : u32) -> tint_symbol {
+  if ((toggle == 1u)) {
+    return tint_symbol(vec4<f32>(0.5, 0.5, 0.5, 0.5), 1.0);
+  }
+  return tint_symbol(vec4<f32>(1.0, 1.0, 1.0, 1.0), 1.0);
+}
+)";
+
+  auto got = Run<EmitVertexPointSize>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
+// Test that we re-use generated structures when we've seen the original return
+// type before.
+TEST_F(EmitVertexPointSizeTest, VertexStage_MultipleShaders) {
+  auto* src = R"(
+struct VertexOut {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+};
+
+[[stage(vertex)]]
+fn entry1() -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>();
+}
+
+[[stage(vertex)]]
+fn entry2() -> [[builtin(position)]] vec4<f32> {
+  return vec4<f32>(1.0, 1.0, 1.0, 1.0);
+}
+
+[[stage(vertex)]]
+fn entry3() -> VertexOut {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.5;
+  return output;
+}
+
+[[stage(vertex)]]
+fn entry4() -> VertexOut {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.75;
+  return output;
+}
+
+)";
+
+  auto* expect = R"(
+struct tint_symbol {
+  [[builtin(position)]]
+  position : vec4<f32>;
+  [[builtin(pointsize)]]
+  tint_pointsize : f32;
+};
+
+struct tint_symbol_1 {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+  [[builtin(pointsize)]]
+  tint_pointsize_1 : f32;
+};
+
+struct VertexOut {
+  [[builtin(position)]]
+  pos : vec4<f32>;
+  [[location(0)]]
+  col : f32;
+};
+
+[[stage(vertex)]]
+fn entry1() -> tint_symbol {
+  return tint_symbol(vec4<f32>(), 1.0);
+}
+
+[[stage(vertex)]]
+fn entry2() -> tint_symbol {
+  return tint_symbol(vec4<f32>(1.0, 1.0, 1.0, 1.0), 1.0);
+}
+
+[[stage(vertex)]]
+fn entry3() -> tint_symbol_1 {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.5;
+  return tint_symbol_1(output.pos, output.col, 1.0);
+}
+
+[[stage(vertex)]]
+fn entry4() -> tint_symbol_1 {
+  var output : VertexOut;
+  output.pos = vec4<f32>();
+  output.col = 0.75;
+  return tint_symbol_1(output.pos, output.col, 1.0);
+}
+)";
+
+  auto got = Run<EmitVertexPointSize>(src);
+
+  EXPECT_EQ(expect, str(got));
+}
+
 TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
   auto* src = R"(
 [[stage(fragment)]]
@@ -87,21 +338,34 @@
 
 TEST_F(EmitVertexPointSizeTest, AttemptSymbolCollision) {
   auto* src = R"(
+struct VertexOut {
+  [[builtin(position)]]
+  tint_pointsize : vec4<f32>;
+};
+
 [[stage(vertex)]]
-fn entry() -> [[builtin(position)]] vec4<f32> {
-  var tint_pointsize : f32;
-  return vec4<f32>();
+fn entry() -> VertexOut {
+  return VertexOut(vec4<f32>());
 }
 )";
 
   auto* expect = R"(
-[[builtin(pointsize)]] var<out> tint_pointsize_1 : f32;
+struct tint_symbol {
+  [[builtin(position)]]
+  tint_pointsize : vec4<f32>;
+  [[builtin(pointsize)]]
+  tint_pointsize_1 : f32;
+};
+
+struct VertexOut {
+  [[builtin(position)]]
+  tint_pointsize : vec4<f32>;
+};
 
 [[stage(vertex)]]
-fn entry() -> [[builtin(position)]] vec4<f32> {
-  tint_pointsize_1 = 1.0;
-  var tint_pointsize : f32;
-  return vec4<f32>();
+fn entry() -> tint_symbol {
+  let tint_symbol_1 : VertexOut = VertexOut(vec4<f32>());
+  return tint_symbol(tint_symbol_1.tint_pointsize, 1.0);
 }
 )";
 
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 47b8aa6..9166b05 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -321,6 +321,20 @@
                      ast::StorageClass::kOutput, nullptr, new_decorations);
     ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
 
+    // Special case for PointSize. The EmitVertexPointSize transform will
+    // produce a struct containing a member with the [[builtin(pointsize)]]
+    // attribute. The SPIR-V reader currently requires that a variable decorated
+    // with PointSize is assigned a _literal_ 1.0 value, so generate that
+    // assignment here to prevent the RHS from using a non-literal expression.
+    if (auto* builtin =
+            ast::GetDecoration<ast::BuiltinDecoration>(new_decorations)) {
+      if (builtin->value() == ast::Builtin::kPointSize) {
+        stores.push_back(ctx.dst->Assign(ctx.dst->Expr(global_var_symbol),
+                                         ctx.dst->Expr(1.f)));
+        return;
+      }
+    }
+
     // Create the assignment instruction.
     ast::Expression* rhs = ctx.dst->Expr(store_value);
     for (auto member : member_accesses) {
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index ea34eb1..8874f5d 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -294,6 +294,99 @@
   Validate(b);
 }
 
+// Test that stores to the PointSize builtin have an RHS which is constant 1.0.
+TEST_F(BuilderTest, EntryPoint_ReturnPointSize) {
+  // struct VertexOut {
+  //   [[builtin(position)]] pos : vec4<f32>;
+  //   [[builtin(pointsize)]] pointsize : f32;
+  // };
+  //
+  // [[stage(vertex)]]
+  // fn vert_main() -> VertexOut {
+  //   if (false) {
+  //     return VertexOut(vec4<f32>(), 1.0);
+  //   }
+  //   return VertexOut(vec4<f32>(1.0, 2.0, 3.0, 0.0), 1.0);
+  // }
+  auto vertex_out = Structure(
+      "VertexOut",
+      {
+          Member("pos", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)}),
+          Member("pointsize", ty.f32(), {Builtin(ast::Builtin::kPointSize)}),
+      });
+  Func("vert_main", {}, vertex_out,
+       {
+           If(Expr(false), Block(Return(Construct(
+                               vertex_out, Construct(ty.vec4<f32>()), 1.f)))),
+           Return(Construct(
+               vertex_out, Construct(ty.vec4<f32>(), 1.f, 2.f, 3.f, 0.f), 1.f)),
+       },
+       {Stage(ast::PipelineStage::kVertex)});
+
+  spirv::Builder& b = SanitizeAndBuild();
+
+  ASSERT_TRUE(b.Build());
+
+  EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %18 "vert_main" %1 %6
+OpName %1 "tint_symbol_1"
+OpName %6 "tint_symbol_2"
+OpName %11 "VertexOut"
+OpMemberName %11 0 "pos"
+OpMemberName %11 1 "pointsize"
+OpName %12 "tint_symbol_3"
+OpName %13 "tint_symbol"
+OpName %18 "vert_main"
+OpDecorate %1 BuiltIn Position
+OpDecorate %6 BuiltIn PointSize
+OpMemberDecorate %11 0 Offset 0
+OpMemberDecorate %11 1 Offset 16
+%4 = OpTypeFloat 32
+%3 = OpTypeVector %4 4
+%2 = OpTypePointer Output %3
+%5 = OpConstantNull %3
+%1 = OpVariable %2 Output %5
+%7 = OpTypePointer Output %4
+%8 = OpConstantNull %4
+%6 = OpVariable %7 Output %8
+%10 = OpTypeVoid
+%11 = OpTypeStruct %3 %4
+%9 = OpTypeFunction %10 %11
+%16 = OpConstant %4 1
+%17 = OpTypeFunction %10
+%20 = OpTypeBool
+%21 = OpConstantFalse %20
+%25 = OpConstantComposite %11 %5 %16
+%27 = OpConstant %4 2
+%28 = OpConstant %4 3
+%29 = OpConstant %4 0
+%30 = OpConstantComposite %3 %16 %27 %28 %29
+%31 = OpConstantComposite %11 %30 %16
+%12 = OpFunction %10 None %9
+%13 = OpFunctionParameter %11
+%14 = OpLabel
+%15 = OpCompositeExtract %3 %13 0
+OpStore %1 %15
+OpStore %6 %16
+OpReturn
+OpFunctionEnd
+%18 = OpFunction %10 None %17
+%19 = OpLabel
+OpSelectionMerge %22 None
+OpBranchConditional %21 %23 %22
+%23 = OpLabel
+%24 = OpFunctionCall %10 %12 %25
+OpReturn
+%22 = OpLabel
+%26 = OpFunctionCall %10 %12 %31
+OpReturn
+OpFunctionEnd
+)");
+
+  Validate(b);
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer