[type-determiner] Update to work with entry point and function stages. This Cl updates the type determiner to work with both styles of entry point definition. Change-Id: Ic48f1a5f0a5820821f9a74380896426a97483049 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28666 Commit-Queue: dan sinclair <dsinclair@chromium.org> Reviewed-by: David Neto <dneto@google.com> Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc index 8e8d853..23bd004 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc
@@ -17,6 +17,7 @@ #include <sstream> #include "src/ast/decorated_variable.h" +#include "src/ast/stage_decoration.h" #include "src/ast/workgroup_decoration.h" namespace tint { @@ -56,6 +57,15 @@ return {1, 1, 1}; } +ast::PipelineStage Function::pipeline_stage() const { + for (const auto& deco : decorations_) { + if (deco->IsStage()) { + return deco->AsStage()->value(); + } + } + return ast::PipelineStage::kNone; +} + void Function::add_referenced_module_variable(Variable* var) { for (const auto* v : referenced_module_vars_) { if (v->name() == var->name()) {
diff --git a/src/ast/function.h b/src/ast/function.h index 078f692..5cb75ba 100644 --- a/src/ast/function.h +++ b/src/ast/function.h
@@ -28,6 +28,7 @@ #include "src/ast/function_decoration.h" #include "src/ast/location_decoration.h" #include "src/ast/node.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/set_decoration.h" #include "src/ast/statement.h" #include "src/ast/type/type.h" @@ -100,6 +101,9 @@ /// return if no workgroup size was set. std::tuple<uint32_t, uint32_t, uint32_t> workgroup_size() const; + /// @returns the functions pipeline stage or None if not set + ast::PipelineStage pipeline_stage() const; + /// Adds the given variable to the list of referenced module variables if it /// is not already included. /// @param var the module variable to add
diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 948b04c..a55ec6c 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc
@@ -217,6 +217,17 @@ } } + // Walk over the caller to callee information and update functions with which + // entry points call those functions. + for (const auto& func : mod_->functions()) { + if (func->pipeline_stage() == ast::PipelineStage::kNone) { + continue; + } + for (const auto& callee : caller_to_callee_[func->name()]) { + set_entry_points(callee, func->name()); + } + } + return true; }
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 512e0b2..67821ad 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc
@@ -37,9 +37,11 @@ #include "src/ast/if_statement.h" #include "src/ast/loop_statement.h" #include "src/ast/member_accessor_expression.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/sint_literal.h" +#include "src/ast/stage_decoration.h" #include "src/ast/struct.h" #include "src/ast/struct_member.h" #include "src/ast/switch_statement.h" @@ -4479,5 +4481,107 @@ EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty()); } +TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { + ast::type::F32Type f32; + + // fn b() {} + // fn c() { b(); } + // fn a() { c(); } + // fn ep_1() { a(); b(); } + // fn ep_2() { c();} + // + // c -> {ep_1, ep_2} + // a -> {ep_1} + // b -> {ep_1, ep_2} + // ep_1 -> {} + // ep_2 -> {} + + ast::VariableList params; + auto func_b = std::make_unique<ast::Function>("b", std::move(params), &f32); + auto* func_b_ptr = func_b.get(); + + auto body = std::make_unique<ast::BlockStatement>(); + func_b->set_body(std::move(body)); + + auto func_c = std::make_unique<ast::Function>("c", std::move(params), &f32); + auto* func_c_ptr = func_c.get(); + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("second"), + std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("b"), + ast::ExpressionList{}))); + func_c->set_body(std::move(body)); + + auto func_a = std::make_unique<ast::Function>("a", std::move(params), &f32); + auto* func_a_ptr = func_a.get(); + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("first"), + std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("c"), + ast::ExpressionList{}))); + func_a->set_body(std::move(body)); + + auto ep_1 = std::make_unique<ast::Function>("ep_1", std::move(params), &f32); + ep_1->add_decoration( + std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex)); + auto* ep_1_ptr = ep_1.get(); + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("call_a"), + std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("a"), + ast::ExpressionList{}))); + body->append(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("call_b"), + std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("b"), + ast::ExpressionList{}))); + ep_1->set_body(std::move(body)); + + auto ep_2 = std::make_unique<ast::Function>("ep_2", std::move(params), &f32); + ep_2->add_decoration( + std::make_unique<ast::StageDecoration>(ast::PipelineStage::kVertex)); + auto* ep_2_ptr = ep_2.get(); + + body = std::make_unique<ast::BlockStatement>(); + body->append(std::make_unique<ast::AssignmentStatement>( + std::make_unique<ast::IdentifierExpression>("call_c"), + std::make_unique<ast::CallExpression>( + std::make_unique<ast::IdentifierExpression>("c"), + ast::ExpressionList{}))); + ep_2->set_body(std::move(body)); + + mod()->AddFunction(std::move(func_b)); + mod()->AddFunction(std::move(func_c)); + mod()->AddFunction(std::move(func_a)); + mod()->AddFunction(std::move(ep_1)); + mod()->AddFunction(std::move(ep_2)); + + // Register the functions and calculate the callers + ASSERT_TRUE(td()->Determine()) << td()->error(); + + const auto& b_eps = func_b_ptr->ancestor_entry_points(); + ASSERT_EQ(2u, b_eps.size()); + EXPECT_EQ("ep_1", b_eps[0]); + EXPECT_EQ("ep_2", b_eps[1]); + + const auto& a_eps = func_a_ptr->ancestor_entry_points(); + ASSERT_EQ(1u, a_eps.size()); + EXPECT_EQ("ep_1", a_eps[0]); + + const auto& c_eps = func_c_ptr->ancestor_entry_points(); + ASSERT_EQ(2u, c_eps.size()); + EXPECT_EQ("ep_1", c_eps[0]); + EXPECT_EQ("ep_2", c_eps[1]); + + EXPECT_TRUE(ep_1_ptr->ancestor_entry_points().empty()); + EXPECT_TRUE(ep_2_ptr->ancestor_entry_points().empty()); +} + } // namespace } // namespace tint