[spirv-writer] Handle non-struct entry point return values
Generate a global variable for the return value and replace return
statements with assignments to this variable.
Add a list of return statements to semantic::Function.
Bug: tint:509
Change-Id: I6bc08fcac7858b48f0eff62199d5011665284220
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44804
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h
index beb7e15..2853ad4 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -381,12 +381,12 @@
using CloneableList = std::vector<Cloneable*>;
/// A map of object in #src to their cloned equivalent in #dst
- std::unordered_map<Cloneable*, Cloneable*> cloned_;
+ std::unordered_map<const Cloneable*, Cloneable*> cloned_;
/// A map of object in #src to the list of cloned objects in #dst.
/// Clone(const std::vector<T*>& v) will use this to insert the map-value list
/// into the target vector/ before cloning and inserting the map-key.
- std::unordered_map<Cloneable*, CloneableList> insert_before_;
+ std::unordered_map<const Cloneable*, CloneableList> insert_before_;
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 9198c39..c5711bf 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -321,6 +321,7 @@
});
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
+ current_function_->return_statements.push_back(r);
return Expression(r->value());
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
@@ -1215,7 +1216,7 @@
auto* sem_func = builder_->create<semantic::Function>(
info->declaration, remap_vars(info->referenced_module_vars),
- remap_vars(info->local_referenced_module_vars),
+ remap_vars(info->local_referenced_module_vars), info->return_statements,
ancestor_entry_points[func->symbol()]);
func_info_to_sem_func.emplace(info, sem_func);
sem.Add(func, sem_func);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index d8a3d81..8fd6900 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -38,6 +38,7 @@
class Function;
class IdentifierExpression;
class MemberAccessorExpression;
+class ReturnStatement;
class UnaryOpExpression;
class Variable;
} // namespace ast
@@ -92,6 +93,7 @@
ast::Function* const declaration;
UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars;
+ std::vector<const ast::ReturnStatement*> return_statements;
// List of transitive calls this function makes
UniqueVector<FunctionInfo*> transitive_calls;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc
index b85dcb8..bcaa8bc 100644
--- a/src/resolver/resolver_test.cc
+++ b/src/resolver/resolver_test.cc
@@ -858,6 +858,30 @@
EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u);
}
+TEST_F(ResolverTest, Function_ReturnStatements) {
+ auto* var = Var("foo", ty.f32(), ast::StorageClass::kFunction);
+
+ auto* ret_1 = create<ast::ReturnStatement>(Expr(1.f));
+ auto* ret_foo = create<ast::ReturnStatement>(Expr("foo"));
+
+ auto* func = Func("my_func", ast::VariableList{}, ty.f32(),
+ ast::StatementList{
+ create<ast::VariableDeclStatement>(var),
+ If(Expr(true), Block(ret_1)),
+ ret_foo,
+ },
+ ast::DecorationList{});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* func_sem = Sem().Get(func);
+ ASSERT_NE(func_sem, nullptr);
+
+ EXPECT_EQ(func_sem->ReturnStatements().size(), 2u);
+ EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1);
+ EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo);
+}
+
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
auto* strct = create<ast::Struct>(
ast::StructMemberList{Member("first_member", ty.i32()),
diff --git a/src/semantic/function.h b/src/semantic/function.h
index 7cff766..fbb5934 100644
--- a/src/semantic/function.h
+++ b/src/semantic/function.h
@@ -29,6 +29,7 @@
class Function;
class GroupDecoration;
class LocationDecoration;
+class ReturnStatement;
} // namespace ast
namespace semantic {
@@ -53,11 +54,13 @@
/// @param declaration the ast::Function
/// @param referenced_module_vars the referenced module variables
/// @param local_referenced_module_vars the locally referenced module
+ /// @param return_statements the function return statements
/// variables
/// @param ancestor_entry_points the ancestor entry points
Function(ast::Function* declaration,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
+ std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points);
/// Destructor
@@ -76,6 +79,10 @@
const std::vector<const Variable*>& LocalReferencedModuleVariables() const {
return local_referenced_module_vars_;
}
+ /// @returns the return statements
+ const std::vector<const ast::ReturnStatement*> ReturnStatements() const {
+ return return_statements_;
+ }
/// @returns the ancestor entry points
const std::vector<Symbol>& AncestorEntryPoints() const {
return ancestor_entry_points_;
@@ -148,6 +155,7 @@
ast::Function* const declaration_;
std::vector<const Variable*> const referenced_module_vars_;
std::vector<const Variable*> const local_referenced_module_vars_;
+ std::vector<const ast::ReturnStatement*> const return_statements_;
std::vector<Symbol> const ancestor_entry_points_;
};
diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc
index 4ae8885..15e0b65 100644
--- a/src/semantic/sem_function.cc
+++ b/src/semantic/sem_function.cc
@@ -57,11 +57,13 @@
Function::Function(ast::Function* declaration,
std::vector<const Variable*> referenced_module_vars,
std::vector<const Variable*> local_referenced_module_vars,
+ std::vector<const ast::ReturnStatement*> return_statements,
std::vector<Symbol> ancestor_entry_points)
: Base(declaration->return_type(), GetParameters(declaration)),
declaration_(declaration),
referenced_module_vars_(std::move(referenced_module_vars)),
local_referenced_module_vars_(std::move(local_referenced_module_vars)),
+ return_statements_(std::move(return_statements)),
ancestor_entry_points_(std::move(ancestor_entry_points)) {}
Function::~Function() = default;
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc
index 97ce585..d404d4a 100644
--- a/src/transform/spirv.cc
+++ b/src/transform/spirv.cc
@@ -17,7 +17,9 @@
#include <string>
#include <utility>
+#include "src/ast/return_statement.h"
#include "src/program_builder.h"
+#include "src/semantic/function.h"
#include "src/semantic/variable.h"
namespace tint {
@@ -102,6 +104,8 @@
continue;
}
+ auto* sem_func = ctx.src->Sem().Get(func);
+
for (auto* param : func->params()) {
// TODO(jrprice): Handle structures by moving the declaration and
// construction to the function body.
@@ -126,21 +130,37 @@
}
}
- // TODO(jrprice): Hoist the return type out to a global variable, and
- // replace return statements with variable assignments.
if (!func->return_type()->Is<type::Void>()) {
- TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
- << "entry point return values are not yet supported";
- continue;
+ // TODO(jrprice): Handle structures by creating a variable for each member
+ // and replacing return statements with extracts+stores.
+ if (func->return_type()->UnwrapAll()->Is<type::Struct>()) {
+ TINT_UNIMPLEMENTED(ctx.dst->Diagnostics())
+ << "structures as entry point return values are not yet supported";
+ continue;
+ }
+
+ // Create a new symbol for the global variable.
+ auto var_symbol = ctx.dst->Symbols().New();
+ // Create the global variable.
+ auto* var = ctx.dst->Var(var_symbol, ctx.Clone(func->return_type()),
+ ast::StorageClass::kOutput, nullptr,
+ ctx.Clone(func->return_type_decorations()));
+ ctx.InsertBefore(func, var);
+
+ // Replace all return statements with stores to the global variable.
+ for (auto* ret : sem_func->ReturnStatements()) {
+ ctx.InsertBefore(
+ ret, ctx.dst->create<ast::AssignmentStatement>(
+ ctx.dst->Expr(var_symbol), ctx.Clone(ret->value())));
+ ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
+ }
}
- // Rewrite the function header to remove the parameters.
- // TODO(jrprice): Change return type to void when return values are handled.
+ // Rewrite the function header to remove the parameters and return value.
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
- ctx.Clone(func->return_type()), ctx.Clone(func->body()),
- ctx.Clone(func->decorations()),
- ctx.Clone(func->return_type_decorations()));
+ ctx.dst->ty.void_(), ctx.Clone(func->body()),
+ ctx.Clone(func->decorations()), ast::DecorationList{});
ctx.Replace(func, new_func);
}
}
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc
index 2f03ad4..2a22b67 100644
--- a/src/transform/spirv_test.cc
+++ b/src/transform/spirv_test.cc
@@ -86,6 +86,97 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnBuiltin) {
+ auto* src = R"(
+[[stage(vertex)]]
+fn vert_main() -> [[builtin(position)]] vec4<f32> {
+ return vec4<f32>(1.0, 2.0, 3.0, 0.0);
+}
+)";
+
+ auto* expect = R"(
+[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>;
+
+[[stage(vertex)]]
+fn vert_main() -> void {
+ tint_symbol_1 = vec4<f32>(1.0, 2.0, 3.0, 0.0);
+ return;
+}
+)";
+
+ auto got = Run<Spirv>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation) {
+ auto* src = R"(
+[[stage(fragment)]]
+fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
+ if (loc_in > 10u) {
+ return 0.5;
+ }
+ return 1.0;
+}
+)";
+
+ auto* expect = R"(
+[[location(0)]] var<in> tint_symbol_1 : u32;
+
+[[location(0)]] var<out> tint_symbol_2 : f32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+ if ((tint_symbol_1 > 10u)) {
+ tint_symbol_2 = 0.5;
+ return;
+ }
+ tint_symbol_2 = 1.0;
+ return;
+}
+)";
+
+ auto got = Run<Spirv>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
+TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation_TypeAlias) {
+ auto* src = R"(
+type myf32 = f32;
+
+[[stage(fragment)]]
+fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 {
+ if (loc_in > 10u) {
+ return 0.5;
+ }
+ return 1.0;
+}
+)";
+
+ auto* expect = R"(
+type myf32 = f32;
+
+[[location(0)]] var<in> tint_symbol_1 : u32;
+
+[[location(0)]] var<out> tint_symbol_2 : myf32;
+
+[[stage(fragment)]]
+fn frag_main() -> void {
+ if ((tint_symbol_1 > 10u)) {
+ tint_symbol_2 = 0.5;
+ return;
+ }
+ tint_symbol_2 = 1.0;
+ return;
+}
+)";
+
+ auto got = Run<Spirv>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) {
auto* src = R"(
[[builtin(sample_index)]] var<in> sample_index : u32;
@@ -164,27 +255,26 @@
// Test that different transforms within the sanitizer interact correctly.
TEST_F(SpirvTest, MultipleTransforms) {
- // TODO(jrprice): Make `mask_out` a return value when supported.
auto* src = R"(
-[[builtin(sample_mask_out)]] var<out> mask_out : u32;
-
[[stage(fragment)]]
fn main([[builtin(sample_index)]] sample_index : u32,
- [[builtin(sample_mask_in)]] mask_in : u32) -> void {
- mask_out = mask_in;
+ [[builtin(sample_mask_in)]] mask_in : u32)
+ -> [[builtin(sample_mask_out)]] u32 {
+ return mask_in;
}
)";
auto* expect = R"(
-[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>;
-
[[builtin(sample_index)]] var<in> tint_symbol_1 : u32;
[[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>;
+[[builtin(sample_mask_out)]] var<out> tint_symbol_3 : array<u32, 1>;
+
[[stage(fragment)]]
fn main() -> void {
- mask_out[0] = tint_symbol_2[0];
+ tint_symbol_3[0] = tint_symbol_2[0];
+ return;
}
)";
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc
index a0c52c3..a45e506 100644
--- a/src/writer/spirv/builder_entry_point_test.cc
+++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -18,6 +18,7 @@
#include "src/ast/builtin.h"
#include "src/ast/builtin_decoration.h"
#include "src/ast/location_decoration.h"
+#include "src/ast/return_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/storage_class.h"
#include "src/ast/variable.h"
@@ -96,6 +97,74 @@
)");
}
+TEST_F(BuilderTest, EntryPoint_ReturnValue) {
+ // [[stage(fragment)]]
+ // fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 {
+ // if (loc_in > 10) {
+ // return 0.5;
+ // }
+ // return 1.0;
+ // }
+ auto* f32 = ty.f32();
+ auto* u32 = ty.u32();
+ auto* loc_in = Var("loc_in", u32, ast::StorageClass::kFunction, nullptr,
+ {create<ast::LocationDecoration>(0)});
+ auto* cond = create<ast::BinaryExpression>(ast::BinaryOp::kGreaterThan,
+ Expr("loc_in"), Expr(10u));
+ Func("frag_main", ast::VariableList{loc_in}, f32,
+ ast::StatementList{
+ If(cond, Block(create<ast::ReturnStatement>(Expr(0.5f)))),
+ create<ast::ReturnStatement>(Expr(1.0f)),
+ },
+ ast::DecorationList{
+ create<ast::StageDecoration>(ast::PipelineStage::kFragment),
+ },
+ ast::DecorationList{create<ast::LocationDecoration>(0)});
+
+ spirv::Builder& b = SanitizeAndBuild();
+
+ ASSERT_TRUE(b.Build());
+
+ // Test that the return value gets hoisted out to a global variable with the
+ // Output storage class, and the return statements are replaced with stores.
+ EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %10 "frag_main" %1 %4
+OpExecutionMode %10 OriginUpperLeft
+OpName %1 "tint_symbol_1"
+OpName %4 "tint_symbol_2"
+OpName %10 "frag_main"
+OpDecorate %1 Location 0
+OpDecorate %4 Location 0
+%3 = OpTypeInt 32 0
+%2 = OpTypePointer Input %3
+%1 = OpVariable %2 Input
+%6 = OpTypeFloat 32
+%5 = OpTypePointer Output %6
+%7 = OpConstantNull %6
+%4 = OpVariable %5 Output %7
+%9 = OpTypeVoid
+%8 = OpTypeFunction %9
+%13 = OpConstant %3 10
+%15 = OpTypeBool
+%18 = OpConstant %6 0.5
+%19 = OpConstant %6 1
+%10 = OpFunction %9 None %8
+%11 = OpLabel
+%12 = OpLoad %3 %1
+%14 = OpUGreaterThan %15 %12 %13
+OpSelectionMerge %16 None
+OpBranchConditional %14 %17 %16
+%17 = OpLabel
+OpStore %4 %18
+OpReturn
+%16 = OpLabel
+OpStore %4 %19
+OpReturn
+OpFunctionEnd
+)");
+}
+
} // namespace
} // namespace spirv
} // namespace writer