Add determination of entrypoint callees.

This Cl updates the type determiner to annotate each function with the
name of any entry points which call into the given function. This will
allow determining in the backends if we need to duplicate the function
due to differing entry point parameter requirements.

Bug: tint:8
Change-Id: Icd7c4ccab72dd6eabcf0abaf1159319949c4ecf5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24760
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/ast/function.cc b/src/ast/function.cc
index f166031..546a27b 100644
--- a/src/ast/function.cc
+++ b/src/ast/function.cc
@@ -51,6 +51,15 @@
   referenced_module_vars_.push_back(var);
 }
 
+void Function::add_ancestor_entry_point(const std::string& ep) {
+  for (const auto& point : ancestor_entry_points_) {
+    if (point == ep) {
+      return;
+    }
+  }
+  ancestor_entry_points_.push_back(ep);
+}
+
 bool Function::IsValid() const {
   for (const auto& param : params_) {
     if (param == nullptr || !param->IsValid())
diff --git a/src/ast/function.h b/src/ast/function.h
index fa26849..bc804a3 100644
--- a/src/ast/function.h
+++ b/src/ast/function.h
@@ -77,6 +77,14 @@
     return referenced_module_vars_;
   }
 
+  /// Adds an ancestor entry point
+  /// @param ep the entry point ancestor
+  void add_ancestor_entry_point(const std::string& ep);
+  /// @returns the ancestor entry points
+  const std::vector<std::string>& ancestor_entry_points() const {
+    return ancestor_entry_points_;
+  }
+
   /// Sets the return type of the function
   /// @param type the return type
   void set_return_type(type::Type* type) { return_type_ = type; }
@@ -108,6 +116,7 @@
   type::Type* return_type_ = nullptr;
   StatementList body_;
   std::vector<Variable*> referenced_module_vars_;
+  std::vector<std::string> ancestor_entry_points_;
 };
 
 /// A list of unique functions
diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc
index 5b0c394..71222fd 100644
--- a/src/ast/function_test.cc
+++ b/src/ast/function_test.cc
@@ -16,6 +16,7 @@
 
 #include "gtest/gtest.h"
 #include "src/ast/kill_statement.h"
+#include "src/ast/pipeline_stage.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/void_type.h"
@@ -77,6 +78,19 @@
   EXPECT_EQ(f.referenced_module_variables()[1], &v2);
 }
 
+TEST_F(FunctionTest, AddDuplicateEntryPoints) {
+  ast::type::VoidType void_type;
+  Function f("func", VariableList{}, &void_type);
+
+  f.add_ancestor_entry_point("main");
+  ASSERT_EQ(1u, f.ancestor_entry_points().size());
+  EXPECT_EQ("main", f.ancestor_entry_points()[0]);
+
+  f.add_ancestor_entry_point("main");
+  ASSERT_EQ(1u, f.ancestor_entry_points().size());
+  EXPECT_EQ("main", f.ancestor_entry_points()[0]);
+}
+
 TEST_F(FunctionTest, IsValid) {
   type::VoidType void_type;
   type::I32Type i32;
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index fa6a03c..82d45a4 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -188,9 +188,27 @@
   if (!DetermineFunctions(mod_->functions())) {
     return false;
   }
+
+  // Walk over the caller to callee information and update functions with which
+  // entry points call those functions.
+  for (const auto& ep : mod_->entry_points()) {
+    for (const auto& callee : caller_to_callee_[ep->function_name()]) {
+      set_entry_points(callee, ep->name());
+    }
+  }
+
   return true;
 }
 
+void TypeDeterminer::set_entry_points(const std::string& fn_name,
+                                      const std::string& ep_name) {
+  name_to_function_[fn_name]->add_ancestor_entry_point(ep_name);
+
+  for (const auto& callee : caller_to_callee_[fn_name]) {
+    set_entry_points(callee, ep_name);
+  }
+}
+
 bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
   for (const auto& func : funcs) {
     if (!DetermineFunction(func.get())) {
@@ -457,6 +475,10 @@
       imp->AddMethodId(ident->name(), ext_id);
       expr->func()->set_result_type(result_type);
     } else {
+      if (current_function_) {
+        caller_to_callee_[current_function_->name()].push_back(ident->name());
+      }
+
       // An identifier with a single name is a function call, not an import
       // lookup which we can handle with the regular identifier lookup.
       if (!DetermineResultType(ident)) {
diff --git a/src/type_determiner.h b/src/type_determiner.h
index 80f2397..ea9b9fc 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -105,6 +105,7 @@
  private:
   void set_error(const Source& src, const std::string& msg);
   void set_referenced_from_function_if_needed(ast::Variable* var);
+  void set_entry_points(const std::string& fn_name, const std::string& ep_name);
 
   bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
   bool DetermineAs(ast::AsExpression* expr);
@@ -123,6 +124,9 @@
   ScopeStack<ast::Variable*> variable_stack_;
   std::unordered_map<std::string, ast::Function*> name_to_function_;
   ast::Function* current_function_ = nullptr;
+
+  // Map from caller functions to callee functions.
+  std::unordered_map<std::string, std::vector<std::string>> caller_to_callee_;
 };
 
 }  // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 2fd9033..183bbc2 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -3746,5 +3746,109 @@
                          testing::Values(GLSLData{"sclamp", GLSLstd450SClamp},
                                          GLSLData{"uclamp", GLSLstd450UClamp}));
 
+TEST_F(TypeDeterminerTest, Function_EntryPoints) {
+  ast::type::F32Type f32;
+
+  // fn b() {}
+  // fn c() { b(); }
+  // fn a() { c(); }
+  // fn ep_1() { a(); b(); }
+  // fn ep_2() { c();}
+  //
+  // c -> {ep_1, ep_2}
+  // a -> {ep_1}
+  // b -> {ep_1, ep_2}
+  // ep_1 -> {}
+  // ep_2 -> {}
+
+  ast::VariableList params;
+  auto func_b = std::make_unique<ast::Function>("b", std::move(params), &f32);
+  auto* func_b_ptr = func_b.get();
+
+  ast::StatementList body;
+  func_b->set_body(std::move(body));
+
+  auto func_c = std::make_unique<ast::Function>("c", std::move(params), &f32);
+  auto* func_c_ptr = func_c.get();
+
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("second"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("b"),
+          ast::ExpressionList{})));
+  func_c->set_body(std::move(body));
+
+  auto func_a = std::make_unique<ast::Function>("a", std::move(params), &f32);
+  auto* func_a_ptr = func_a.get();
+
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("first"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("c"),
+          ast::ExpressionList{})));
+  func_a->set_body(std::move(body));
+
+  auto ep_1_func =
+      std::make_unique<ast::Function>("ep_1_func", std::move(params), &f32);
+  auto* ep_1_func_ptr = ep_1_func.get();
+
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("call_a"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("a"),
+          ast::ExpressionList{})));
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("call_b"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("b"),
+          ast::ExpressionList{})));
+  ep_1_func->set_body(std::move(body));
+
+  auto ep_2_func =
+      std::make_unique<ast::Function>("ep_2_func", std::move(params), &f32);
+  auto* ep_2_func_ptr = ep_2_func.get();
+
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("call_c"),
+      std::make_unique<ast::CallExpression>(
+          std::make_unique<ast::IdentifierExpression>("c"),
+          ast::ExpressionList{})));
+  ep_2_func->set_body(std::move(body));
+
+  auto ep_1 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
+                                                "ep_1", "ep_1_func");
+  auto ep_2 = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kVertex,
+                                                "ep_2", "ep_2_func");
+
+  mod()->AddFunction(std::move(func_b));
+  mod()->AddFunction(std::move(func_c));
+  mod()->AddFunction(std::move(func_a));
+  mod()->AddFunction(std::move(ep_1_func));
+  mod()->AddFunction(std::move(ep_2_func));
+
+  mod()->AddEntryPoint(std::move(ep_1));
+  mod()->AddEntryPoint(std::move(ep_2));
+
+  // Register the functions and calculate the callers
+  ASSERT_TRUE(td()->Determine()) << td()->error();
+
+  const auto& b_eps = func_b_ptr->ancestor_entry_points();
+  ASSERT_EQ(2u, b_eps.size());
+  EXPECT_EQ("ep_1", b_eps[0]);
+  EXPECT_EQ("ep_2", b_eps[1]);
+
+  const auto& a_eps = func_a_ptr->ancestor_entry_points();
+  ASSERT_EQ(1u, a_eps.size());
+  EXPECT_EQ("ep_1", a_eps[0]);
+
+  const auto& c_eps = func_c_ptr->ancestor_entry_points();
+  ASSERT_EQ(2u, c_eps.size());
+  EXPECT_EQ("ep_1", c_eps[0]);
+  EXPECT_EQ("ep_2", c_eps[1]);
+
+  EXPECT_TRUE(ep_1_func_ptr->ancestor_entry_points().empty());
+  EXPECT_TRUE(ep_2_func_ptr->ancestor_entry_points().empty());
+}
+
 }  // namespace
 }  // namespace tint