TD: Fix O(2^n) of reachable-by-entry-point

Re-jig the code so that this can be performed in O(n).

Fixed: tint:245
Change-Id: I6dc341c0313e3a1c808f15c66e0c70a7339640e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/43641
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index af104ca..8a77248 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -137,29 +137,9 @@
     return false;
   }
 
-  // Walk over the caller to callee information and update functions with
-  // which entry points call those functions.
-  for (auto* func : builder_->AST().Functions()) {
-    if (!func->IsEntryPoint()) {
-      continue;
-    }
-    for (const auto& callee : caller_to_callee_[func->symbol()]) {
-      set_entry_points(callee, func->symbol());
-    }
-  }
-
   return true;
 }
 
-void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
-  auto* info = symbol_to_function_.at(fn_sym);
-  info->ancestor_entry_points.add(ep_sym);
-
-  for (const auto& callee : caller_to_callee_[fn_sym]) {
-    set_entry_points(callee, ep_sym);
-  }
-}
-
 bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
   for (auto* func : funcs) {
     if (!DetermineFunction(func)) {
@@ -439,9 +419,6 @@
     }
   } else {
     if (current_function_) {
-      caller_to_callee_[current_function_->declaration->symbol()].push_back(
-          ident->symbol());
-
       auto callee_func_it = symbol_to_function_.find(ident->symbol());
       if (callee_func_it == symbol_to_function_.end()) {
         if (current_function_->declaration->symbol() == ident->symbol()) {
@@ -457,6 +434,13 @@
       }
       auto* callee_func = callee_func_it->second;
 
+      // Note: Requires called functions to be resolved first.
+      // This is currently guaranteed as functions must be declared before use.
+      current_function_->transitive_calls.add(callee_func);
+      for (auto* transitive_call : callee_func->transitive_calls) {
+        current_function_->transitive_calls.add(transitive_call);
+      }
+
       // We inherit any referenced variables from the callee.
       for (auto* var : callee_func->referenced_module_vars) {
         set_referenced_from_function_if_needed(var, false);
@@ -1004,6 +988,25 @@
 void TypeDeterminer::CreateSemanticNodes() const {
   auto& sem = builder_->Sem();
 
+  // Collate all the 'ancestor_entry_points' - this is a map of function symbol
+  // to all the entry points that transitively call the function.
+  std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
+  for (auto* func : builder_->AST().Functions()) {
+    auto it = function_to_info_.find(func);
+    if (it == function_to_info_.end()) {
+      continue;  // Type determination has likely errored. Process what we can.
+    }
+
+    auto* info = it->second;
+    if (!func->IsEntryPoint()) {
+      continue;
+    }
+    for (auto* call : info->transitive_calls) {
+      auto& vec = ancestor_entry_points[call->declaration->symbol()];
+      vec.emplace_back(func->symbol());
+    }
+  }
+
   // Create semantic nodes for all ast::Variables
   for (auto it : variable_to_info_) {
     auto* var = it.first;
@@ -1038,10 +1041,11 @@
   for (auto it : function_to_info_) {
     auto* func = it.first;
     auto* info = it.second;
+
     auto* sem_func = builder_->create<semantic::Function>(
         info->declaration, remap_vars(info->referenced_module_vars),
         remap_vars(info->local_referenced_module_vars),
-        info->ancestor_entry_points);
+        ancestor_entry_points[func->symbol()]);
     func_info_to_sem_func.emplace(info, sem_func);
     sem.Add(func, sem_func);
   }
diff --git a/src/type_determiner.h b/src/type_determiner.h
index fb1b7a9..677a21c 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -91,7 +91,9 @@
     ast::Function* const declaration;
     UniqueVector<VariableInfo*> referenced_module_vars;
     UniqueVector<VariableInfo*> local_referenced_module_vars;
-    UniqueVector<Symbol> ancestor_entry_points;
+
+    // List of transitive calls this function makes
+    UniqueVector<FunctionInfo*> transitive_calls;
   };
 
   /// Structure holding semantic information about an expression.
@@ -118,7 +120,8 @@
   /// @param funcs the functions to check
   /// @returns true if the determination was successful
   bool DetermineFunctions(const ast::FunctionList& funcs);
-  /// Determines type information for a function
+  /// Determines type information for a function. Requires all dependency
+  /// (callee) functions to have DetermineFunction() called on them first.
   /// @param func the function to check
   /// @returns true if the determination was successful
   bool DetermineFunction(ast::Function* func);
@@ -162,7 +165,6 @@
                             uint32_t* id);
 
   void set_referenced_from_function_if_needed(VariableInfo* var, bool local);
-  void set_entry_points(const Symbol& fn_sym, Symbol ep_sym);
 
   bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
   bool DetermineBinary(ast::BinaryExpression* expr);
@@ -202,9 +204,6 @@
   semantic::Statement* current_statement_ = nullptr;
   BlockAllocator<VariableInfo> variable_infos_;
   BlockAllocator<FunctionInfo> function_infos_;
-
-  // Map from caller functions to callee functions.
-  std::unordered_map<Symbol, std::vector<Symbol>> caller_to_callee_;
 };
 
 }  // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index edb95cf..75c11df 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -3269,6 +3269,53 @@
   EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());
 }
 
+// Check for linear-time traversal of functions reachable from entry points.
+// See: crbug.com/tint/245
+TEST_F(TypeDeterminerTest, Function_EntryPoints_LinearTime) {
+  // fn lNa() { }
+  // fn lNb() { }
+  // ...
+  // fn l2a() { l3a(); l3b(); }
+  // fn l2b() { l3a(); l3b(); }
+  // fn l1a() { l2a(); l2b(); }
+  // fn l1b() { l2a(); l2b(); }
+  // fn main() { l1a(); l1b(); }
+
+  static constexpr int levels = 64;
+
+  auto fn_a = [](int level) { return "l" + std::to_string(level + 1) + "a"; };
+  auto fn_b = [](int level) { return "l" + std::to_string(level + 1) + "b"; };
+
+  Func(fn_a(levels), {}, ty.void_(), {}, {});
+  Func(fn_b(levels), {}, ty.void_(), {}, {});
+
+  for (int i = levels - 1; i >= 0; i--) {
+    Func(fn_a(i), {}, ty.void_(),
+         {
+             create<ast::CallStatement>(Call(fn_a(i + 1))),
+             create<ast::CallStatement>(Call(fn_b(i + 1))),
+         },
+         {});
+    Func(fn_b(i), {}, ty.void_(),
+         {
+             create<ast::CallStatement>(Call(fn_a(i + 1))),
+             create<ast::CallStatement>(Call(fn_b(i + 1))),
+         },
+         {});
+  }
+
+  Func("main", {}, ty.void_(),
+       {
+           create<ast::CallStatement>(Call(fn_a(0))),
+           create<ast::CallStatement>(Call(fn_b(0))),
+       },
+       {
+           create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+       });
+
+  ASSERT_TRUE(td()->Determine()) << td()->error();
+}
+
 using TypeDeterminerTextureIntrinsicTest =
     TypeDeterminerTestWithParam<ast::intrinsic::test::TextureOverloadCase>;