[spirv-writer] Handle non-struct entry point return values Generate a global variable for the return value and replace return statements with assignments to this variable. Add a list of return statements to semantic::Function. Bug: tint:509 Change-Id: I6bc08fcac7858b48f0eff62199d5011665284220 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44804 Commit-Queue: James Price <jrprice@google.com> Auto-Submit: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h index beb7e15..2853ad4 100644 --- a/src/clone_context.h +++ b/src/clone_context.h
@@ -381,12 +381,12 @@ using CloneableList = std::vector<Cloneable*>; /// A map of object in #src to their cloned equivalent in #dst - std::unordered_map<Cloneable*, Cloneable*> cloned_; + std::unordered_map<const Cloneable*, Cloneable*> cloned_; /// A map of object in #src to the list of cloned objects in #dst. /// Clone(const std::vector<T*>& v) will use this to insert the map-value list /// into the target vector/ before cloning and inserting the map-key. - std::unordered_map<Cloneable*, CloneableList> insert_before_; + std::unordered_map<const Cloneable*, CloneableList> insert_before_; /// Cloneable transform functions registered with ReplaceAll() std::vector<CloneableTransform> transforms_;
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 9198c39..c5711bf 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc
@@ -321,6 +321,7 @@ }); } if (auto* r = stmt->As<ast::ReturnStatement>()) { + current_function_->return_statements.push_back(r); return Expression(r->value()); } if (auto* s = stmt->As<ast::SwitchStatement>()) { @@ -1215,7 +1216,7 @@ auto* sem_func = builder_->create<semantic::Function>( info->declaration, remap_vars(info->referenced_module_vars), - remap_vars(info->local_referenced_module_vars), + remap_vars(info->local_referenced_module_vars), info->return_statements, ancestor_entry_points[func->symbol()]); func_info_to_sem_func.emplace(info, sem_func); sem.Add(func, sem_func);
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index d8a3d81..8fd6900 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h
@@ -38,6 +38,7 @@ class Function; class IdentifierExpression; class MemberAccessorExpression; +class ReturnStatement; class UnaryOpExpression; class Variable; } // namespace ast @@ -92,6 +93,7 @@ ast::Function* const declaration; UniqueVector<VariableInfo*> referenced_module_vars; UniqueVector<VariableInfo*> local_referenced_module_vars; + std::vector<const ast::ReturnStatement*> return_statements; // List of transitive calls this function makes UniqueVector<FunctionInfo*> transitive_calls;
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index b85dcb8..bcaa8bc 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc
@@ -858,6 +858,30 @@ EXPECT_EQ(func_sem->ReferencedModuleVariables().size(), 0u); } +TEST_F(ResolverTest, Function_ReturnStatements) { + auto* var = Var("foo", ty.f32(), ast::StorageClass::kFunction); + + auto* ret_1 = create<ast::ReturnStatement>(Expr(1.f)); + auto* ret_foo = create<ast::ReturnStatement>(Expr("foo")); + + auto* func = Func("my_func", ast::VariableList{}, ty.f32(), + ast::StatementList{ + create<ast::VariableDeclStatement>(var), + If(Expr(true), Block(ret_1)), + ret_foo, + }, + ast::DecorationList{}); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* func_sem = Sem().Get(func); + ASSERT_NE(func_sem, nullptr); + + EXPECT_EQ(func_sem->ReturnStatements().size(), 2u); + EXPECT_EQ(func_sem->ReturnStatements()[0], ret_1); + EXPECT_EQ(func_sem->ReturnStatements()[1], ret_foo); +} + TEST_F(ResolverTest, Expr_MemberAccessor_Struct) { auto* strct = create<ast::Struct>( ast::StructMemberList{Member("first_member", ty.i32()),
diff --git a/src/semantic/function.h b/src/semantic/function.h index 7cff766..fbb5934 100644 --- a/src/semantic/function.h +++ b/src/semantic/function.h
@@ -29,6 +29,7 @@ class Function; class GroupDecoration; class LocationDecoration; +class ReturnStatement; } // namespace ast namespace semantic { @@ -53,11 +54,13 @@ /// @param declaration the ast::Function /// @param referenced_module_vars the referenced module variables /// @param local_referenced_module_vars the locally referenced module + /// @param return_statements the function return statements /// variables /// @param ancestor_entry_points the ancestor entry points Function(ast::Function* declaration, std::vector<const Variable*> referenced_module_vars, std::vector<const Variable*> local_referenced_module_vars, + std::vector<const ast::ReturnStatement*> return_statements, std::vector<Symbol> ancestor_entry_points); /// Destructor @@ -76,6 +79,10 @@ const std::vector<const Variable*>& LocalReferencedModuleVariables() const { return local_referenced_module_vars_; } + /// @returns the return statements + const std::vector<const ast::ReturnStatement*> ReturnStatements() const { + return return_statements_; + } /// @returns the ancestor entry points const std::vector<Symbol>& AncestorEntryPoints() const { return ancestor_entry_points_; @@ -148,6 +155,7 @@ ast::Function* const declaration_; std::vector<const Variable*> const referenced_module_vars_; std::vector<const Variable*> const local_referenced_module_vars_; + std::vector<const ast::ReturnStatement*> const return_statements_; std::vector<Symbol> const ancestor_entry_points_; };
diff --git a/src/semantic/sem_function.cc b/src/semantic/sem_function.cc index 4ae8885..15e0b65 100644 --- a/src/semantic/sem_function.cc +++ b/src/semantic/sem_function.cc
@@ -57,11 +57,13 @@ Function::Function(ast::Function* declaration, std::vector<const Variable*> referenced_module_vars, std::vector<const Variable*> local_referenced_module_vars, + std::vector<const ast::ReturnStatement*> return_statements, std::vector<Symbol> ancestor_entry_points) : Base(declaration->return_type(), GetParameters(declaration)), declaration_(declaration), referenced_module_vars_(std::move(referenced_module_vars)), local_referenced_module_vars_(std::move(local_referenced_module_vars)), + return_statements_(std::move(return_statements)), ancestor_entry_points_(std::move(ancestor_entry_points)) {} Function::~Function() = default;
diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 97ce585..d404d4a 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc
@@ -17,7 +17,9 @@ #include <string> #include <utility> +#include "src/ast/return_statement.h" #include "src/program_builder.h" +#include "src/semantic/function.h" #include "src/semantic/variable.h" namespace tint { @@ -102,6 +104,8 @@ continue; } + auto* sem_func = ctx.src->Sem().Get(func); + for (auto* param : func->params()) { // TODO(jrprice): Handle structures by moving the declaration and // construction to the function body. @@ -126,21 +130,37 @@ } } - // TODO(jrprice): Hoist the return type out to a global variable, and - // replace return statements with variable assignments. if (!func->return_type()->Is<type::Void>()) { - TINT_UNIMPLEMENTED(ctx.dst->Diagnostics()) - << "entry point return values are not yet supported"; - continue; + // TODO(jrprice): Handle structures by creating a variable for each member + // and replacing return statements with extracts+stores. + if (func->return_type()->UnwrapAll()->Is<type::Struct>()) { + TINT_UNIMPLEMENTED(ctx.dst->Diagnostics()) + << "structures as entry point return values are not yet supported"; + continue; + } + + // Create a new symbol for the global variable. + auto var_symbol = ctx.dst->Symbols().New(); + // Create the global variable. + auto* var = ctx.dst->Var(var_symbol, ctx.Clone(func->return_type()), + ast::StorageClass::kOutput, nullptr, + ctx.Clone(func->return_type_decorations())); + ctx.InsertBefore(func, var); + + // Replace all return statements with stores to the global variable. + for (auto* ret : sem_func->ReturnStatements()) { + ctx.InsertBefore( + ret, ctx.dst->create<ast::AssignmentStatement>( + ctx.dst->Expr(var_symbol), ctx.Clone(ret->value()))); + ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>()); + } } - // Rewrite the function header to remove the parameters. - // TODO(jrprice): Change return type to void when return values are handled. + // Rewrite the function header to remove the parameters and return value. auto* new_func = ctx.dst->create<ast::Function>( func->source(), ctx.Clone(func->symbol()), ast::VariableList{}, - ctx.Clone(func->return_type()), ctx.Clone(func->body()), - ctx.Clone(func->decorations()), - ctx.Clone(func->return_type_decorations())); + ctx.dst->ty.void_(), ctx.Clone(func->body()), + ctx.Clone(func->decorations()), ast::DecorationList{}); ctx.Replace(func, new_func); } }
diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc index 2f03ad4..2a22b67 100644 --- a/src/transform/spirv_test.cc +++ b/src/transform/spirv_test.cc
@@ -86,6 +86,97 @@ EXPECT_EQ(expect, str(got)); } +TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnBuiltin) { + auto* src = R"( +[[stage(vertex)]] +fn vert_main() -> [[builtin(position)]] vec4<f32> { + return vec4<f32>(1.0, 2.0, 3.0, 0.0); +} +)"; + + auto* expect = R"( +[[builtin(position)]] var<out> tint_symbol_1 : vec4<f32>; + +[[stage(vertex)]] +fn vert_main() -> void { + tint_symbol_1 = vec4<f32>(1.0, 2.0, 3.0, 0.0); + return; +} +)"; + + auto got = Run<Spirv>(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation) { + auto* src = R"( +[[stage(fragment)]] +fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 { + if (loc_in > 10u) { + return 0.5; + } + return 1.0; +} +)"; + + auto* expect = R"( +[[location(0)]] var<in> tint_symbol_1 : u32; + +[[location(0)]] var<out> tint_symbol_2 : f32; + +[[stage(fragment)]] +fn frag_main() -> void { + if ((tint_symbol_1 > 10u)) { + tint_symbol_2 = 0.5; + return; + } + tint_symbol_2 = 1.0; + return; +} +)"; + + auto got = Run<Spirv>(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(SpirvTest, HandleEntryPointIOTypes_ReturnLocation_TypeAlias) { + auto* src = R"( +type myf32 = f32; + +[[stage(fragment)]] +fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] myf32 { + if (loc_in > 10u) { + return 0.5; + } + return 1.0; +} +)"; + + auto* expect = R"( +type myf32 = f32; + +[[location(0)]] var<in> tint_symbol_1 : u32; + +[[location(0)]] var<out> tint_symbol_2 : myf32; + +[[stage(fragment)]] +fn frag_main() -> void { + if ((tint_symbol_1 > 10u)) { + tint_symbol_2 = 0.5; + return; + } + tint_symbol_2 = 1.0; + return; +} +)"; + + auto got = Run<Spirv>(src); + + EXPECT_EQ(expect, str(got)); +} + TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) { auto* src = R"( [[builtin(sample_index)]] var<in> sample_index : u32; @@ -164,27 +255,26 @@ // Test that different transforms within the sanitizer interact correctly. TEST_F(SpirvTest, MultipleTransforms) { - // TODO(jrprice): Make `mask_out` a return value when supported. auto* src = R"( -[[builtin(sample_mask_out)]] var<out> mask_out : u32; - [[stage(fragment)]] fn main([[builtin(sample_index)]] sample_index : u32, - [[builtin(sample_mask_in)]] mask_in : u32) -> void { - mask_out = mask_in; + [[builtin(sample_mask_in)]] mask_in : u32) + -> [[builtin(sample_mask_out)]] u32 { + return mask_in; } )"; auto* expect = R"( -[[builtin(sample_mask_out)]] var<out> mask_out : array<u32, 1>; - [[builtin(sample_index)]] var<in> tint_symbol_1 : u32; [[builtin(sample_mask_in)]] var<in> tint_symbol_2 : array<u32, 1>; +[[builtin(sample_mask_out)]] var<out> tint_symbol_3 : array<u32, 1>; + [[stage(fragment)]] fn main() -> void { - mask_out[0] = tint_symbol_2[0]; + tint_symbol_3[0] = tint_symbol_2[0]; + return; } )";
diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index a0c52c3..a45e506 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc
@@ -18,6 +18,7 @@ #include "src/ast/builtin.h" #include "src/ast/builtin_decoration.h" #include "src/ast/location_decoration.h" +#include "src/ast/return_statement.h" #include "src/ast/stage_decoration.h" #include "src/ast/storage_class.h" #include "src/ast/variable.h" @@ -96,6 +97,74 @@ )"); } +TEST_F(BuilderTest, EntryPoint_ReturnValue) { + // [[stage(fragment)]] + // fn frag_main([[location(0)]] loc_in : u32) -> [[location(0)]] f32 { + // if (loc_in > 10) { + // return 0.5; + // } + // return 1.0; + // } + auto* f32 = ty.f32(); + auto* u32 = ty.u32(); + auto* loc_in = Var("loc_in", u32, ast::StorageClass::kFunction, nullptr, + {create<ast::LocationDecoration>(0)}); + auto* cond = create<ast::BinaryExpression>(ast::BinaryOp::kGreaterThan, + Expr("loc_in"), Expr(10u)); + Func("frag_main", ast::VariableList{loc_in}, f32, + ast::StatementList{ + If(cond, Block(create<ast::ReturnStatement>(Expr(0.5f)))), + create<ast::ReturnStatement>(Expr(1.0f)), + }, + ast::DecorationList{ + create<ast::StageDecoration>(ast::PipelineStage::kFragment), + }, + ast::DecorationList{create<ast::LocationDecoration>(0)}); + + spirv::Builder& b = SanitizeAndBuild(); + + ASSERT_TRUE(b.Build()); + + // Test that the return value gets hoisted out to a global variable with the + // Output storage class, and the return statements are replaced with stores. + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %10 "frag_main" %1 %4 +OpExecutionMode %10 OriginUpperLeft +OpName %1 "tint_symbol_1" +OpName %4 "tint_symbol_2" +OpName %10 "frag_main" +OpDecorate %1 Location 0 +OpDecorate %4 Location 0 +%3 = OpTypeInt 32 0 +%2 = OpTypePointer Input %3 +%1 = OpVariable %2 Input +%6 = OpTypeFloat 32 +%5 = OpTypePointer Output %6 +%7 = OpConstantNull %6 +%4 = OpVariable %5 Output %7 +%9 = OpTypeVoid +%8 = OpTypeFunction %9 +%13 = OpConstant %3 10 +%15 = OpTypeBool +%18 = OpConstant %6 0.5 +%19 = OpConstant %6 1 +%10 = OpFunction %9 None %8 +%11 = OpLabel +%12 = OpLoad %3 %1 +%14 = OpUGreaterThan %15 %12 %13 +OpSelectionMerge %16 None +OpBranchConditional %14 %17 %16 +%17 = OpLabel +OpStore %4 %18 +OpReturn +%16 = OpLabel +OpStore %4 %19 +OpReturn +OpFunctionEnd +)"); +} + } // namespace } // namespace spirv } // namespace writer