[spirv-writer] Support optional trailing return.
This CL updates the SPIRV-Writer to inject an OpReturn as the trailing
statement in a function if the function does not end with a `discard` or
a `return` statement.
R=bclayton@google.com, dneto@google.com
Fixes: tint:302
Change-Id: I2e7c7beff15ad30c779c591bb75cf97fc0960bf7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33160
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc
index 1ca5c0e..1fc8aa4 100644
--- a/src/writer/spirv/builder_call_test.cc
+++ b/src/writer/spirv/builder_call_test.cc
@@ -95,6 +95,7 @@
%12 = OpFunction %11 None %10
%13 = OpLabel
%14 = OpFunctionCall %2 %3 %15 %15
+OpReturn
OpFunctionEnd
)");
}
@@ -156,6 +157,7 @@
%12 = OpFunction %2 None %11
%13 = OpLabel
%14 = OpFunctionCall %2 %4 %15 %15
+OpReturn
OpFunctionEnd
)");
}
diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc
index a68f780..8b59b2e 100644
--- a/src/writer/spirv/builder_function_decoration_test.cc
+++ b/src/writer/spirv/builder_function_decoration_test.cc
@@ -258,9 +258,11 @@
%1 = OpTypeFunction %2
%3 = OpFunction %2 None %1
%4 = OpLabel
+OpReturn
OpFunctionEnd
%5 = OpFunction %2 None %1
%6 = OpLabel
+OpReturn
OpFunctionEnd
)");
}
diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc
index 528ee18..1a6de92 100644
--- a/src/writer/spirv/builder_function_test.cc
+++ b/src/writer/spirv/builder_function_test.cc
@@ -18,6 +18,7 @@
#include "spirv/unified1/spirv.h"
#include "spirv/unified1/spirv.hpp11"
#include "src/ast/decorated_variable.h"
+#include "src/ast/discard_statement.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/member_accessor_expression.h"
@@ -51,15 +52,83 @@
ast::Function func("a_func", {}, &void_type, create<ast::BlockStatement>());
ASSERT_TRUE(b.GenerateFunction(&func));
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "tint_615f66756e63"
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeVoid
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "tint_615f66756e63"
+%2 = OpTypeVoid
%1 = OpTypeFunction %2
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+OpReturn
+OpFunctionEnd
)");
+}
- ASSERT_GE(b.functions().size(), 1u);
- const auto& ret = b.functions()[0];
- EXPECT_EQ(DumpInstruction(ret.declaration()), R"(%3 = OpFunction %2 None %1
+TEST_F(BuilderTest, Function_Terminator_Return) {
+ ast::type::VoidType void_type;
+
+ auto* body = create<ast::BlockStatement>();
+ body->append(create<ast::ReturnStatement>());
+
+ ast::Function func("a_func", {}, &void_type, body);
+
+ ASSERT_TRUE(b.GenerateFunction(&func));
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "tint_615f66756e63"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, Function_Terminator_ReturnValue) {
+ ast::type::VoidType void_type;
+ ast::type::F32Type f32;
+
+ auto* var_a = create<ast::Variable>("a", ast::StorageClass::kPrivate, &f32);
+ td.RegisterVariableForTesting(var_a);
+
+ auto* body = create<ast::BlockStatement>();
+ body->append(
+ create<ast::ReturnStatement>(create<ast::IdentifierExpression>("a")));
+ ASSERT_TRUE(td.DetermineResultType(body)) << td.error();
+
+ ast::Function func("a_func", {}, &void_type, body);
+
+ ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
+ ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "tint_61"
+OpName %7 "tint_615f66756e63"
+%3 = OpTypeFloat 32
+%2 = OpTypePointer Private %3
+%4 = OpConstantNull %3
+%1 = OpVariable %2 Private %4
+%6 = OpTypeVoid
+%5 = OpTypeFunction %6
+%7 = OpFunction %6 None %5
+%8 = OpLabel
+%9 = OpLoad %3 %1
+OpReturnValue %9
+OpFunctionEnd
+)");
+}
+
+TEST_F(BuilderTest, Function_Terminator_Discard) {
+ ast::type::VoidType void_type;
+
+ auto* body = create<ast::BlockStatement>();
+ body->append(create<ast::DiscardStatement>());
+
+ ast::Function func("a_func", {}, &void_type, body);
+
+ ASSERT_TRUE(b.GenerateFunction(&func));
+ EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "tint_615f66756e63"
+%2 = OpTypeVoid
+%1 = OpTypeFunction %2
+%3 = OpFunction %2 None %1
+%4 = OpLabel
+OpKill
+OpFunctionEnd
)");
}
diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc
index de8cecf..c877885 100644
--- a/src/writer/spirv/builder_intrinsic_test.cc
+++ b/src/writer/spirv/builder_intrinsic_test.cc
@@ -940,6 +940,7 @@
%8 = OpLabel
%11 = OpLoad %3 %1
%9 = OpExtInst %3 %10 Round %11
+OpReturn
OpFunctionEnd
)");
}
@@ -967,6 +968,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8
+OpReturn
OpFunctionEnd
)");
}
@@ -994,6 +996,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1042,6 +1045,7 @@
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpExtInst %6 %7 Length %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1066,6 +1070,7 @@
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpExtInst %6 %7 Length %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1090,6 +1095,7 @@
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpExtInst %6 %8 Normalize %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1118,6 +1124,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1146,6 +1153,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1177,6 +1185,7 @@
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpExtInst %6 %7 Distance %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1202,6 +1211,7 @@
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpExtInst %6 %7 Distance %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1228,6 +1238,7 @@
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpExtInst %6 %8 Cross %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1255,6 +1266,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8 %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1284,6 +1296,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10 %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1320,6 +1333,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1347,6 +1361,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1377,6 +1392,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1404,6 +1420,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1434,6 +1451,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1461,6 +1479,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1492,6 +1511,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1519,6 +1539,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1550,6 +1571,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8 %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1579,6 +1601,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10 %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1609,6 +1632,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %7 )" + param.op +
R"( %8 %8 %8
+OpReturn
OpFunctionEnd
)");
}
@@ -1638,6 +1662,7 @@
%4 = OpLabel
%5 = OpExtInst %6 %8 )" + param.op +
R"( %10 %10 %10
+OpReturn
OpFunctionEnd
)");
}
@@ -1673,6 +1698,7 @@
%4 = OpLabel
%13 = OpLoad %7 %5
%11 = OpExtInst %9 %12 Determinant %13
+OpReturn
OpFunctionEnd
)");
}
diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc
index 2738b28..b4b110b 100644
--- a/src/writer/spirv/builder_switch_test.cc
+++ b/src/writer/spirv/builder_switch_test.cc
@@ -141,6 +141,7 @@
%11 = OpLabel
OpBranch %9
%9 = OpLabel
+OpReturn
OpFunctionEnd
)");
}
@@ -198,6 +199,7 @@
OpStore %1 %12
OpBranch %9
%9 = OpLabel
+OpReturn
OpFunctionEnd
)");
}
@@ -288,6 +290,7 @@
OpStore %1 %16
OpBranch %9
%9 = OpLabel
+OpReturn
OpFunctionEnd
)");
}
@@ -379,6 +382,7 @@
OpStore %1 %16
OpBranch %9
%9 = OpLabel
+OpReturn
OpFunctionEnd
)");
}
@@ -501,6 +505,7 @@
%11 = OpLabel
OpBranch %9
%9 = OpLabel
+OpReturn
OpFunctionEnd
)");
}
diff --git a/src/writer/spirv/function.cc b/src/writer/spirv/function.cc
index 01da96c..5c5c726 100644
--- a/src/writer/spirv/function.cc
+++ b/src/writer/spirv/function.cc
@@ -17,6 +17,15 @@
namespace tint {
namespace writer {
namespace spirv {
+namespace {
+
+// Returns true if the given Op is a function terminator
+bool OpIsFunctionTerminator(spv::Op op) {
+ return op == spv::Op::OpReturn || op == spv::Op::OpReturnValue ||
+ op == spv::Op::OpKill;
+}
+
+} // namespace
Function::Function()
: declaration_(Instruction{spv::Op::OpNop, {}}),
@@ -47,6 +56,17 @@
cb(inst);
}
+ bool needs_terminator = false;
+ if (instructions_.empty()) {
+ needs_terminator = true;
+ } else {
+ const auto& last = instructions_.back();
+ needs_terminator = !OpIsFunctionTerminator(last.opcode());
+ }
+ if (needs_terminator) {
+ cb(Instruction{spv::Op::OpReturn, {}});
+ }
+
cb(Instruction{spv::Op::OpFunctionEnd, {}});
}
diff --git a/test/compute_boids.wgsl b/test/compute_boids.wgsl
index 1d32673..6bee6d3 100644
--- a/test/compute_boids.wgsl
+++ b/test/compute_boids.wgsl
@@ -26,7 +26,6 @@
(a_pos.x * cos(angle)) - (a_pos.y * sin(angle)),
(a_pos.x * sin(angle)) + (a_pos.y * cos(angle)));
gl_Position = vec4<f32>(pos + a_particlePos, 0.0, 1.0);
- return;
}
# fragment shader
@@ -35,7 +34,6 @@
[[stage(fragment)]]
fn frag_main() -> void {
fragColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
- return;
}
# compute shader
@@ -137,6 +135,4 @@
# Write back
particlesB.particles[index].pos = vPos;
particlesB.particles[index].vel = vVel;
-
- return;
}
diff --git a/test/cube.wgsl b/test/cube.wgsl
index eb67254..d56f32f 100644
--- a/test/cube.wgsl
+++ b/test/cube.wgsl
@@ -28,7 +28,6 @@
fn vtx_main() -> void {
Position = uniforms.modelViewProjectionMatrix * cur_position;
vtxFragColor = color;
- return;
}
# Fragment shader
@@ -38,5 +37,4 @@
[[stage(fragment)]]
fn frag_main() -> void {
outColor = fragColor;
- return;
}
diff --git a/test/function.wgsl b/test/function.wgsl
index f2b2e22..604b8ff 100644
--- a/test/function.wgsl
+++ b/test/function.wgsl
@@ -19,5 +19,4 @@
[[stage(compute)]]
[[workgroup_size(2)]]
fn ep() -> void {
- return;
}
diff --git a/test/simple.wgsl b/test/simple.wgsl
index ff5f38e..844bfcb 100644
--- a/test/simple.wgsl
+++ b/test/simple.wgsl
@@ -15,7 +15,6 @@
[[location(0)]] var<out> gl_FragColor : vec4<f32>;
fn bar() -> void {
- return;
}
[[stage(fragment)]]
@@ -23,5 +22,4 @@
var a : vec2<f32> = vec2<f32>();
gl_FragColor = vec4<f32>(0.4, 0.4, 0.8, 1.0);
bar();
- return;
}
diff --git a/test/triangle.wgsl b/test/triangle.wgsl
index 0159eb6..d612c94 100644
--- a/test/triangle.wgsl
+++ b/test/triangle.wgsl
@@ -24,7 +24,6 @@
[[stage(vertex)]]
fn vtx_main() -> void {
Position = vec4<f32>(pos[VertexIndex], 0.0, 1.0);
- return;
}
# Fragment shader
@@ -33,5 +32,4 @@
[[stage(fragment)]]
fn frag_main() -> void {
outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
- return;
}