diff --git a/codeflash/models/call_graph.py b/codeflash/models/call_graph.py index ee2521c4f..b0d800dca 100644 --- a/codeflash/models/call_graph.py +++ b/codeflash/models/call_graph.py @@ -98,15 +98,27 @@ def descendants(self, node: FunctionNode, max_depth: int | None = None) -> set[F def ancestors(self, node: FunctionNode, max_depth: int | None = None) -> set[FunctionNode]: visited: set[FunctionNode] = set() - queue: deque[tuple[FunctionNode, int]] = deque([(node, 0)]) - while queue: - current, depth = queue.popleft() - if max_depth is not None and depth >= max_depth: - continue - for edge in self.callers_of(current): - if edge.caller not in visited: - visited.add(edge.caller) - queue.append((edge.caller, depth + 1)) + reverse_map = self.reverse + + if max_depth is None: + queue: deque[FunctionNode] = deque([node]) + while queue: + current = queue.popleft() + for edge in reverse_map.get(current, []): + if edge.caller not in visited: + visited.add(edge.caller) + queue.append(edge.caller) + else: + queue_with_depth: deque[tuple[FunctionNode, int]] = deque([(node, 0)]) + while queue_with_depth: + current, depth = queue_with_depth.popleft() + if depth >= max_depth: + continue + for edge in reverse_map.get(current, []): + if edge.caller not in visited: + visited.add(edge.caller) + queue_with_depth.append((edge.caller, depth + 1)) + return visited def subgraph(self, nodes: set[FunctionNode]) -> CallGraph: