TD: Fix O(2^n) of reachable-by-entry-point
Re-jig the code so that this can be performed in O(n).
Fixed: tint:245
Change-Id: I6dc341c0313e3a1c808f15c66e0c70a7339640e5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/43641
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index af104ca..8a77248 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -137,29 +137,9 @@
return false;
}
- // Walk over the caller to callee information and update functions with
- // which entry points call those functions.
- for (auto* func : builder_->AST().Functions()) {
- if (!func->IsEntryPoint()) {
- continue;
- }
- for (const auto& callee : caller_to_callee_[func->symbol()]) {
- set_entry_points(callee, func->symbol());
- }
- }
-
return true;
}
-void TypeDeterminer::set_entry_points(const Symbol& fn_sym, Symbol ep_sym) {
- auto* info = symbol_to_function_.at(fn_sym);
- info->ancestor_entry_points.add(ep_sym);
-
- for (const auto& callee : caller_to_callee_[fn_sym]) {
- set_entry_points(callee, ep_sym);
- }
-}
-
bool TypeDeterminer::DetermineFunctions(const ast::FunctionList& funcs) {
for (auto* func : funcs) {
if (!DetermineFunction(func)) {
@@ -439,9 +419,6 @@
}
} else {
if (current_function_) {
- caller_to_callee_[current_function_->declaration->symbol()].push_back(
- ident->symbol());
-
auto callee_func_it = symbol_to_function_.find(ident->symbol());
if (callee_func_it == symbol_to_function_.end()) {
if (current_function_->declaration->symbol() == ident->symbol()) {
@@ -457,6 +434,13 @@
}
auto* callee_func = callee_func_it->second;
+ // Note: Requires called functions to be resolved first.
+ // This is currently guaranteed as functions must be declared before use.
+ current_function_->transitive_calls.add(callee_func);
+ for (auto* transitive_call : callee_func->transitive_calls) {
+ current_function_->transitive_calls.add(transitive_call);
+ }
+
// We inherit any referenced variables from the callee.
for (auto* var : callee_func->referenced_module_vars) {
set_referenced_from_function_if_needed(var, false);
@@ -1004,6 +988,25 @@
void TypeDeterminer::CreateSemanticNodes() const {
auto& sem = builder_->Sem();
+ // Collate all the 'ancestor_entry_points' - this is a map of function symbol
+ // to all the entry points that transitively call the function.
+ std::unordered_map<Symbol, std::vector<Symbol>> ancestor_entry_points;
+ for (auto* func : builder_->AST().Functions()) {
+ auto it = function_to_info_.find(func);
+ if (it == function_to_info_.end()) {
+ continue; // Type determination has likely errored. Process what we can.
+ }
+
+ auto* info = it->second;
+ if (!func->IsEntryPoint()) {
+ continue;
+ }
+ for (auto* call : info->transitive_calls) {
+ auto& vec = ancestor_entry_points[call->declaration->symbol()];
+ vec.emplace_back(func->symbol());
+ }
+ }
+
// Create semantic nodes for all ast::Variables
for (auto it : variable_to_info_) {
auto* var = it.first;
@@ -1038,10 +1041,11 @@
for (auto it : function_to_info_) {
auto* func = it.first;
auto* info = it.second;
+
auto* sem_func = builder_->create<semantic::Function>(
info->declaration, remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars),
- info->ancestor_entry_points);
+ ancestor_entry_points[func->symbol()]);
func_info_to_sem_func.emplace(info, sem_func);
sem.Add(func, sem_func);
}
diff --git a/src/type_determiner.h b/src/type_determiner.h
index fb1b7a9..677a21c 100644
--- a/src/type_determiner.h
+++ b/src/type_determiner.h
@@ -91,7 +91,9 @@
ast::Function* const declaration;
UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars;
- UniqueVector<Symbol> ancestor_entry_points;
+
+ // List of transitive calls this function makes
+ UniqueVector<FunctionInfo*> transitive_calls;
};
/// Structure holding semantic information about an expression.
@@ -118,7 +120,8 @@
/// @param funcs the functions to check
/// @returns true if the determination was successful
bool DetermineFunctions(const ast::FunctionList& funcs);
- /// Determines type information for a function
+ /// Determines type information for a function. Requires all dependency
+ /// (callee) functions to have DetermineFunction() called on them first.
/// @param func the function to check
/// @returns true if the determination was successful
bool DetermineFunction(ast::Function* func);
@@ -162,7 +165,6 @@
uint32_t* id);
void set_referenced_from_function_if_needed(VariableInfo* var, bool local);
- void set_entry_points(const Symbol& fn_sym, Symbol ep_sym);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineBinary(ast::BinaryExpression* expr);
@@ -202,9 +204,6 @@
semantic::Statement* current_statement_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_;
BlockAllocator<FunctionInfo> function_infos_;
-
- // Map from caller functions to callee functions.
- std::unordered_map<Symbol, std::vector<Symbol>> caller_to_callee_;
};
} // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index edb95cf..75c11df 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -3269,6 +3269,53 @@
EXPECT_TRUE(ep_2_sem->AncestorEntryPoints().empty());
}
+// Check for linear-time traversal of functions reachable from entry points.
+// See: crbug.com/tint/245
+TEST_F(TypeDeterminerTest, Function_EntryPoints_LinearTime) {
+ // fn lNa() { }
+ // fn lNb() { }
+ // ...
+ // fn l2a() { l3a(); l3b(); }
+ // fn l2b() { l3a(); l3b(); }
+ // fn l1a() { l2a(); l2b(); }
+ // fn l1b() { l2a(); l2b(); }
+ // fn main() { l1a(); l1b(); }
+
+ static constexpr int levels = 64;
+
+ auto fn_a = [](int level) { return "l" + std::to_string(level + 1) + "a"; };
+ auto fn_b = [](int level) { return "l" + std::to_string(level + 1) + "b"; };
+
+ Func(fn_a(levels), {}, ty.void_(), {}, {});
+ Func(fn_b(levels), {}, ty.void_(), {}, {});
+
+ for (int i = levels - 1; i >= 0; i--) {
+ Func(fn_a(i), {}, ty.void_(),
+ {
+ create<ast::CallStatement>(Call(fn_a(i + 1))),
+ create<ast::CallStatement>(Call(fn_b(i + 1))),
+ },
+ {});
+ Func(fn_b(i), {}, ty.void_(),
+ {
+ create<ast::CallStatement>(Call(fn_a(i + 1))),
+ create<ast::CallStatement>(Call(fn_b(i + 1))),
+ },
+ {});
+ }
+
+ Func("main", {}, ty.void_(),
+ {
+ create<ast::CallStatement>(Call(fn_a(0))),
+ create<ast::CallStatement>(Call(fn_b(0))),
+ },
+ {
+ create<ast::StageDecoration>(ast::PipelineStage::kVertex),
+ });
+
+ ASSERT_TRUE(td()->Determine()) << td()->error();
+}
+
using TypeDeterminerTextureIntrinsicTest =
TypeDeterminerTestWithParam<ast::intrinsic::test::TextureOverloadCase>;