Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 193 additions & 6 deletions code_review_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self, db_path: str | Path) -> None:
)
self._conn.commit()
run_migrations(self._conn)
self._nxg_cache: nx.DiGraph | None = None
self._nxg_cache: tuple[nx.DiGraph, str] | None = None
self._cache_lock = threading.Lock()
# Query result caches with version-based invalidation
self._cache_version = 0
Expand Down Expand Up @@ -851,18 +851,205 @@ def _batch_get_nodes(self, qualified_names: set[str]) -> list[GraphNode]:

# --- Internal helpers ---

def _build_networkx_graph(self, include_edge_kind: str | None = None) -> nx.DiGraph:
    """Build (or return cached) in-memory NetworkX directed graph from stored edges.

    Exactly one graph is cached at a time, stored as a ``(graph, cache_key)``
    tuple in ``self._nxg_cache``. Asking for a different edge kind than the
    cached one rebuilds the graph and replaces the cached entry.

    Args:
        include_edge_kind: If provided (and non-empty), only include edges of
            this kind. ``None`` or ``""`` selects every edge.

    Returns:
        Directed graph whose edges carry a ``kind`` attribute. The returned
        object is the cached instance itself, so callers should treat it as
        read-only.
    """
    # "all" is the cache key for the unfiltered graph.
    cache_key = include_edge_kind or "all"
    with self._cache_lock:
        cached = self._nxg_cache
        if cached is not None:
            cached_graph, cached_key = cached
            if cached_key == cache_key:
                return cached_graph

        # Cache miss: run exactly one query, filtered only when requested.
        if include_edge_kind:
            rows = self._conn.execute(
                "SELECT * FROM edges WHERE kind = ?", (include_edge_kind,)
            ).fetchall()
        else:
            rows = self._conn.execute("SELECT * FROM edges").fetchall()

        graph: nx.DiGraph = nx.DiGraph()
        graph.add_edges_from(
            (row["source_qualified"], row["target_qualified"], {"kind": row["kind"]})
            for row in rows
        )
        # Always store the (graph, key) pair so the hit path above can unpack it.
        self._nxg_cache = (graph, cache_key)
        return graph

def find_shortest_path(
    self, source: str, target: str, edge_kind: str | None = None
) -> list[str] | None:
    """Return one shortest (fewest-edges) directed path from source to target.

    Args:
        source: Qualified name of the start node.
        target: Qualified name of the end node.
        edge_kind: Optional edge type filter (e.g., "CALLS"); when given, only
            edges of that kind are considered.

    Returns:
        The qualified names along the path, endpoints included, or None when
        either node is missing from the graph or no path connects them.
    """
    graph = self._build_networkx_graph(include_edge_kind=edge_kind)
    try:
        return nx.shortest_path(graph, source, target)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # Unknown node or unreachable target: both are reported as "no path".
        return None

def find_all_shortest_paths(
    self, source: str, target: str, edge_kind: str | None = None
) -> list[list[str]]:
    """Return every shortest directed path from source to target.

    Args:
        source: Qualified name of the start node.
        target: Qualified name of the end node.
        edge_kind: Optional edge type filter; when given, only edges of that
            kind are considered.

    Returns:
        List of paths (each a list of qualified names). Empty when either
        node is missing from the graph or the target is unreachable.
    """
    graph = self._build_networkx_graph(include_edge_kind=edge_kind)
    try:
        # The generator is consumed inside the try because networkx reports
        # "no path" lazily, while the paths are being produced.
        routes = [*nx.all_shortest_paths(graph, source, target)]
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return []
    return routes

def find_path_with_depth_limit(
    self,
    source: str,
    target: str,
    max_depth: int = 3,
    edge_kind: str | None = None,
) -> list[str] | None:
    """Find a shortest path from source to target that uses at most max_depth edges.

    Args:
        source: Qualified name of the start node.
        target: Qualified name of the end node.
        max_depth: Maximum number of edges the path may contain.
        edge_kind: Optional edge type filter, matching the other path
            finders; None (the default) considers every edge.

    Returns:
        Path list (endpoints included), or None if either node is missing,
        the target is unreachable, or the shortest path is longer than
        max_depth edges.
    """
    g = self._build_networkx_graph(include_edge_kind=edge_kind)
    try:
        # One search is enough: the hop count is derivable from the path.
        path = nx.shortest_path(g, source, target)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return None
    # A path of k nodes has k - 1 edges.
    if len(path) - 1 <= max_depth:
        return path
    return None

def get_pagerank(self, alpha: float = 0.85, max_iter: int = 100) -> dict[str, float]:
    """Compute PageRank scores over the full (unfiltered) edge graph.

    Args:
        alpha: Damping factor (0-1).
        max_iter: Maximum number of power-iteration steps.

    Returns:
        Dict mapping qualified name to PageRank score, or an empty dict when
        the iteration does not converge within max_iter steps.
    """
    graph = self._build_networkx_graph()
    try:
        ranks = nx.pagerank(graph, alpha=alpha, max_iter=max_iter)
    except nx.PowerIterationFailedConvergence:
        # Non-convergence is reported as "no scores" rather than raised.
        return {}
    return ranks

def get_degree_centrality(self) -> dict[str, float]:
    """Compute combined degree centrality for each node.

    Returns:
        Dict mapping qualified name to the sum of its in-degree centrality
        and its out-degree centrality.
    """
    graph = self._build_networkx_graph()
    inbound = nx.in_degree_centrality(graph)
    outbound = nx.out_degree_centrality(graph)
    combined: dict[str, float] = {}
    for name in graph.nodes():
        combined[name] = inbound.get(name, 0) + outbound.get(name, 0)
    return combined

def get_betweenness_centrality(self, k: int | None = None) -> dict[str, float]:
    """Compute betweenness centrality.

    Args:
        k: Sample k nodes for faster approximation (None = all nodes).
            A value larger than the number of nodes falls back to the exact
            computation instead of failing.

    Returns:
        Dict mapping qualified name to betweenness score.
    """
    g = self._build_networkx_graph()
    # networkx draws k distinct pivot nodes; asking for more pivots than the
    # graph has nodes raises ValueError, so use the exact algorithm instead.
    if k is not None and k > g.number_of_nodes():
        k = None
    return nx.betweenness_centrality(g, k=k)

def find_strongly_connected_components(self) -> list[set[str]]:
    """Collect the strongly connected components of the full edge graph.

    Returns:
        List of components, each a set of qualified names, as produced by
        ``nx.strongly_connected_components``.
    """
    graph = self._build_networkx_graph()
    components = nx.strongly_connected_components(graph)
    return [*components]

def find_cycles(self, max_length: int = 10) -> list[list[str]]:
    """Find simple cycles in the graph that have at most max_length nodes.

    At most 1000 cycles are returned. Enumeration stops as soon as that many
    have been found, so the cap also bounds the work and memory used.

    Warning: This can still be expensive on large, dense graphs. Keep
    max_length small.

    Args:
        max_length: Maximum cycle length to report. Values <= 0 yield no
            cycles.

    Returns:
        List of cycles (each cycle is a list of qualified names). Empty when
        networkx reports an error during enumeration.
    """
    from itertools import islice

    if max_length <= 0:
        # networkx rejects negative bounds; nothing can match anyway.
        return []
    g = self._build_networkx_graph()
    try:
        # length_bound (networkx >= 3.1) prunes the search itself, and islice
        # stops pulling from the generator once the cap is reached.
        bounded = nx.simple_cycles(g, length_bound=max_length)
        return list(islice(bounded, 1000))
    except nx.NetworkXException:
        # Best-effort: library-level failures are reported as "no cycles".
        return []

def has_cycle(self) -> bool:
    """Report whether the full edge graph contains at least one cycle.

    Returns:
        True if ``nx.find_cycle`` locates a cycle, False if it signals that
        none exists.
    """
    graph = self._build_networkx_graph()
    try:
        nx.find_cycle(graph)
    except nx.NetworkXNoCycle:
        return False
    return True

def get_cycle_for_node(self, node: str) -> list[str] | None:
    """Find a shortest cycle that includes the given node, if any.

    A cycle through ``node`` must leave via one of its successors and come
    back, so it is enough to look for a shortest path from each successor
    back to ``node``. This avoids enumerating every simple cycle in the
    graph.

    Args:
        node: Qualified name of the node to look for.

    Returns:
        The cycle as a list of qualified names starting at ``node`` (the
        closing edge back to ``node`` is implied), or None if the node is
        not in the graph or lies on no cycle. A self-loop is reported as
        ``[node]``.
    """
    g = self._build_networkx_graph()
    if node not in g:
        return None

    shortest_return: list[str] | None = None
    for successor in g.successors(node):
        try:
            # For a self-loop the successor is the node itself and the
            # return path is just [node].
            return_path = nx.shortest_path(g, successor, node)
        except nx.NetworkXNoPath:
            continue
        if shortest_return is None or len(return_path) < len(shortest_return):
            shortest_return = return_path

    if shortest_return is None:
        return None
    # return_path ends at `node`; drop that last element so the node appears
    # exactly once, at the front of the cycle.
    return [node, *shortest_return[:-1]]

def topological_sort(self) -> list[str]:
    """Return nodes of the full edge graph in topological order (DAG only).

    Returns:
        Qualified names in topological order, or an empty list when networkx
        raises ``NetworkXError``.

    Raises:
        nx.NetworkXUnfeasible: If the graph has cycles. NOTE(review): per the
            networkx exception hierarchy this is not a ``NetworkXError``
            subclass, so the handler below does not intercept it - confirm
            that propagating is the intended contract.
    """
    graph = self._build_networkx_graph()
    try:
        # The ordering is produced lazily, so it is materialised inside the try.
        ordering = [*nx.topological_sort(graph)]
    except nx.NetworkXError:
        return []
    return ordering

def find_leaf_nodes(self) -> list[str]:
    """List sink nodes: those with no outgoing edges in the full edge graph.

    Returns:
        Qualified names whose out-degree is zero, in graph node order.
    """
    graph = self._build_networkx_graph()
    return [name for name, fan_out in graph.out_degree() if fan_out == 0]

def find_root_nodes(self) -> list[str]:
    """List source nodes: those with no incoming edges in the full edge graph.

    Returns:
        Qualified names whose in-degree is zero, in graph node order.
    """
    graph = self._build_networkx_graph()
    return [name for name, fan_in in graph.in_degree() if fan_in == 0]

def get_connected_components(self) -> list[set[str]]:
    """Collect the weakly connected components of the full edge graph.

    Returns:
        List of components, each a set of qualified names, as produced by
        ``nx.weakly_connected_components`` (edge direction is ignored).
    """
    graph = self._build_networkx_graph()
    components = nx.weakly_connected_components(graph)
    return [*components]

def get_node_importance(self, method: str = "pagerank") -> list[tuple[str, float]]:
    """Rank nodes by an importance score.

    Args:
        method: Scoring method - "pagerank", "degree", or "betweenness".

    Returns:
        List of (qualified_name, score) tuples sorted by score descending.
        An unrecognised method yields an empty list.
    """
    # Map each supported method name to the scorer that computes it; only the
    # selected scorer is actually called.
    scorers = {
        "pagerank": self.get_pagerank,
        "degree": self.get_degree_centrality,
        "betweenness": self.get_betweenness_centrality,
    }
    scorer = scorers.get(method)
    if scorer is None:
        return []
    ranked = sorted(scorer().items(), key=lambda pair: pair[1], reverse=True)
    return ranked

def _make_qualified(self, node: NodeInfo) -> str:
if node.kind == "File":
return node.file_path
Expand Down
Loading