[spirv-writer] Support optional trailing return.

This CL updates the SPIRV-Writer to inject an OpReturn as the trailing
statement in a function if the function does not end with a `discard` or
a `return` statement.

R=bclayton@google.com, dneto@google.com

Fixes: tint:302
Change-Id: I2e7c7beff15ad30c779c591bb75cf97fc0960bf7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33160
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc
index 1ca5c0e..1fc8aa4 100644
--- a/src/writer/spirv/builder_call_test.cc
+++ b/src/writer/spirv/builder_call_test.cc
@@ -95,6 +95,7 @@
 %12 = OpFunction %11 None %10
 %13 = OpLabel
 %14 = OpFunctionCall %2 %3 %15 %15
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -156,6 +157,7 @@
 %12 = OpFunction %2 None %11
 %13 = OpLabel
 %14 = OpFunctionCall %2 %4 %15 %15
+OpReturn
 OpFunctionEnd
 )");
 }
diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc
index a68f780..8b59b2e 100644
--- a/src/writer/spirv/builder_function_decoration_test.cc
+++ b/src/writer/spirv/builder_function_decoration_test.cc
@@ -258,9 +258,11 @@
 %1 = OpTypeFunction %2
 %3 = OpFunction %2 None %1
 %4 = OpLabel
+OpReturn
 OpFunctionEnd
 %5 = OpFunction %2 None %1
 %6 = OpLabel
+OpReturn
 OpFunctionEnd
 )");
 }
diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc
index 528ee18..1a6de92 100644
--- a/src/writer/spirv/builder_function_test.cc
+++ b/src/writer/spirv/builder_function_test.cc
@@ -18,6 +18,7 @@
 #include "spirv/unified1/spirv.h"
 #include "spirv/unified1/spirv.hpp11"
 #include "src/ast/decorated_variable.h"
+#include "src/ast/discard_statement.h"
 #include "src/ast/function.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/member_accessor_expression.h"
@@ -51,15 +52,83 @@
   ast::Function func("a_func", {}, &void_type, create<ast::BlockStatement>());
 
   ASSERT_TRUE(b.GenerateFunction(&func));
-  EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "tint_615f66756e63"
-)");
-  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+  EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "tint_615f66756e63"
+%2 = OpTypeVoid
 %1 = OpTypeFunction %2
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+OpReturn
+OpFunctionEnd
 )");
+}
 
-  ASSERT_GE(b.functions().size(), 1u);
-  const auto& ret = b.functions()[0];
-  EXPECT_EQ(DumpInstruction(ret.declaration()), R"(%3 = OpFunction %2 None %1
+TEST_F(BuilderTest, Function_Terminator_Return) {
+  ast::type::VoidType void_type;
+
+  auto* body = create<ast::BlockStatement>();
+  body->append(create<ast::ReturnStatement>());
+
+  ast::Function func("a_func", {}, &void_type, body);
+
+  ASSERT_TRUE(b.GenerateFunction(&func));
+  EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "tint_615f66756e63"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, Function_Terminator_ReturnValue) {
+  ast::type::VoidType void_type;
+  ast::type::F32Type f32;
+
+  auto* var_a = create<ast::Variable>("a", ast::StorageClass::kPrivate, &f32);
+  td.RegisterVariableForTesting(var_a);
+
+  auto* body = create<ast::BlockStatement>();
+  body->append(
+      create<ast::ReturnStatement>(create<ast::IdentifierExpression>("a")));
+  ASSERT_TRUE(td.DetermineResultType(body)) << td.error();
+
+  ast::Function func("a_func", {}, &void_type, body);
+
+  ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
+  ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+  EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "tint_61"
+OpName %7 "tint_615f66756e63"
+%3 = OpTypeFloat 32
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+%6 = OpTypeVoid
+%5 = OpTypeFunction %6
+%7 = OpFunction %6 None %5
+%8 = OpLabel
+%9 = OpLoad %3 %1
+OpReturnValue %9
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, Function_Terminator_Discard) {
+  ast::type::VoidType void_type;
+
+  auto* body = create<ast::BlockStatement>();
+  body->append(create<ast::DiscardStatement>());
+
+  ast::Function func("a_func", {}, &void_type, body);
+
+  ASSERT_TRUE(b.GenerateFunction(&func));
+  EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "tint_615f66756e63"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+OpKill
+OpFunctionEnd
 )");
 }
 
diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc
index de8cecf..c877885 100644
--- a/src/writer/spirv/builder_intrinsic_test.cc
+++ b/src/writer/spirv/builder_intrinsic_test.cc
@@ -940,6 +940,7 @@
 %8 = OpLabel
 %11 = OpLoad %3 %1
 %9 = OpExtInst %3 %10 Round %11
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -967,6 +968,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -994,6 +996,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1042,6 +1045,7 @@
 %3 = OpFunction %2 None %1
 %4 = OpLabel
 %5 = OpExtInst %6 %7 Length %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1066,6 +1070,7 @@
 %3 = OpFunction %2 None %1
 %4 = OpLabel
 %5 = OpExtInst %6 %7 Length %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1090,6 +1095,7 @@
 %3 = OpFunction %2 None %1
 %4 = OpLabel
 %5 = OpExtInst %6 %8 Normalize %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1118,6 +1124,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1146,6 +1153,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1177,6 +1185,7 @@
 %3 = OpFunction %2 None %1
 %4 = OpLabel
 %5 = OpExtInst %6 %7 Distance %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1202,6 +1211,7 @@
 %3 = OpFunction %2 None %1
 %4 = OpLabel
 %5 = OpExtInst %6 %7 Distance %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1228,6 +1238,7 @@
 %3 = OpFunction %2 None %1
 %4 = OpLabel
 %5 = OpExtInst %6 %8 Cross %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1255,6 +1266,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8 %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1284,6 +1296,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10 %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1320,6 +1333,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1347,6 +1361,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1377,6 +1392,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1404,6 +1420,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1434,6 +1451,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1461,6 +1479,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1492,6 +1511,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1519,6 +1539,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1550,6 +1571,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8 %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1579,6 +1601,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10 %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1609,6 +1632,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %7 )" + param.op +
                                 R"( %8 %8 %8
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1638,6 +1662,7 @@
 %4 = OpLabel
 %5 = OpExtInst %6 %8 )" + param.op +
                                 R"( %10 %10 %10
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -1673,6 +1698,7 @@
 %4 = OpLabel
 %13 = OpLoad %7 %5
 %11 = OpExtInst %9 %12 Determinant %13
+OpReturn
 OpFunctionEnd
 )");
 }
diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc
index 2738b28..b4b110b 100644
--- a/src/writer/spirv/builder_switch_test.cc
+++ b/src/writer/spirv/builder_switch_test.cc
@@ -141,6 +141,7 @@
 %11 = OpLabel
 OpBranch %9
 %9 = OpLabel
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -198,6 +199,7 @@
 OpStore %1 %12
 OpBranch %9
 %9 = OpLabel
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -288,6 +290,7 @@
 OpStore %1 %16
 OpBranch %9
 %9 = OpLabel
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -379,6 +382,7 @@
 OpStore %1 %16
 OpBranch %9
 %9 = OpLabel
+OpReturn
 OpFunctionEnd
 )");
 }
@@ -501,6 +505,7 @@
 %11 = OpLabel
 OpBranch %9
 %9 = OpLabel
+OpReturn
 OpFunctionEnd
 )");
 }
diff --git a/src/writer/spirv/function.cc b/src/writer/spirv/function.cc
index 01da96c..5c5c726 100644
--- a/src/writer/spirv/function.cc
+++ b/src/writer/spirv/function.cc
@@ -17,6 +17,15 @@
 namespace tint {
 namespace writer {
 namespace spirv {
+namespace {
+
+// Returns true if the given Op is a function terminator
+bool OpIsFunctionTerminator(spv::Op op) {
+  return op == spv::Op::OpReturn || op == spv::Op::OpReturnValue ||
+         op == spv::Op::OpKill;
+}
+
+}  // namespace
 
 Function::Function()
     : declaration_(Instruction{spv::Op::OpNop, {}}),
@@ -47,6 +56,17 @@
     cb(inst);
   }
 
+  bool needs_terminator = false;
+  if (instructions_.empty()) {
+    needs_terminator = true;
+  } else {
+    const auto& last = instructions_.back();
+    needs_terminator = !OpIsFunctionTerminator(last.opcode());
+  }
+  if (needs_terminator) {
+    cb(Instruction{spv::Op::OpReturn, {}});
+  }
+
   cb(Instruction{spv::Op::OpFunctionEnd, {}});
 }
 
diff --git a/test/compute_boids.wgsl b/test/compute_boids.wgsl
index 1d32673..6bee6d3 100644
--- a/test/compute_boids.wgsl
+++ b/test/compute_boids.wgsl
@@ -26,7 +26,6 @@
       (a_pos.x * cos(angle)) - (a_pos.y * sin(angle)),
       (a_pos.x * sin(angle)) + (a_pos.y * cos(angle)));
   gl_Position = vec4<f32>(pos + a_particlePos, 0.0, 1.0);
-  return;
 }
 
 # fragment shader
@@ -35,7 +34,6 @@
 [[stage(fragment)]]
 fn frag_main() -> void {
   fragColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
-  return;
 }
 
 # compute shader
@@ -137,6 +135,4 @@
   # Write back
   particlesB.particles[index].pos = vPos;
   particlesB.particles[index].vel = vVel;
-
-  return;
 }
diff --git a/test/cube.wgsl b/test/cube.wgsl
index eb67254..d56f32f 100644
--- a/test/cube.wgsl
+++ b/test/cube.wgsl
@@ -28,7 +28,6 @@
 fn vtx_main() -> void {
    Position = uniforms.modelViewProjectionMatrix * cur_position;
    vtxFragColor = color;
-   return;
 }
 
 # Fragment shader
@@ -38,5 +37,4 @@
 [[stage(fragment)]]
 fn frag_main() -> void {
   outColor = fragColor;
-  return;
 }
diff --git a/test/function.wgsl b/test/function.wgsl
index f2b2e22..604b8ff 100644
--- a/test/function.wgsl
+++ b/test/function.wgsl
@@ -19,5 +19,4 @@
 [[stage(compute)]]
 [[workgroup_size(2)]]
 fn ep() -> void {
-  return;
 }
diff --git a/test/simple.wgsl b/test/simple.wgsl
index ff5f38e..844bfcb 100644
--- a/test/simple.wgsl
+++ b/test/simple.wgsl
@@ -15,7 +15,6 @@
 [[location(0)]] var<out> gl_FragColor : vec4<f32>;
 
 fn bar() -> void {
-  return;
 }
 
 [[stage(fragment)]]
@@ -23,5 +22,4 @@
     var a : vec2<f32> = vec2<f32>();
     gl_FragColor = vec4<f32>(0.4, 0.4, 0.8, 1.0);
     bar();
-    return;
 }
diff --git a/test/triangle.wgsl b/test/triangle.wgsl
index 0159eb6..d612c94 100644
--- a/test/triangle.wgsl
+++ b/test/triangle.wgsl
@@ -24,7 +24,6 @@
 [[stage(vertex)]]
 fn vtx_main() -> void {
   Position = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
-  return;
 }
 
 # Fragment shader
@@ -33,5 +32,4 @@
 [[stage(fragment)]]
 fn frag_main() -> void {
   outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
-  return;
 }