From 14cdcaaf8731d182dfac514176893806675da864 Mon Sep 17 00:00:00 2001
From: dontmindaditya
Date: Sun, 5 Apr 2026 03:57:45 -0700
Subject: [PATCH] asyncio

---
Review notes (this section is not part of the commit message):
* Line breaks, indentation and blank context lines were reconstructed from
  the hunk line counts and Python syntax; re-check them against the target
  tree before applying.
* flows.py reads node.extra.get("is_async"), while parser.py passes
  is_async= directly to NodeInfo(...). No hunk in this patch adds such a
  field or copies it into "extra" -- TODO confirm the flag actually reaches
  GraphNode.extra.
* No hunk in this patch emits "AWAITS" edges -- TODO confirm a producer
  exists; otherwise async_call_count will presumably always be 0.
* async_call_count is incremented before the "visited" check, so AWAITS
  edges to already-visited targets are counted too -- TODO confirm intended.

 code_review_graph/flows.py  | 34 ++++++++++++++++++++++++++++++++--
 code_review_graph/parser.py | 10 ++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/code_review_graph/flows.py b/code_review_graph/flows.py
index 8c7374c..724d6da 100644
--- a/code_review_graph/flows.py
+++ b/code_review_graph/flows.py
@@ -45,6 +45,16 @@
     re.compile(r"^handle_"),
 ]
 
+# Name patterns for async entry points.
+_ASYNC_ENTRY_PATTERNS: list[re.Pattern[str]] = [
+    re.compile(r"^async_main$"),
+    re.compile(r"^run$"),
+    re.compile(r"^start$"),
+    re.compile(r"^serve$"),
+    re.compile(r"^listen$"),
+    re.compile(r"^bootstrap$"),
+]
+
 
 # ---------------------------------------------------------------------------
 # Entry-point detection
@@ -73,13 +83,22 @@ def _matches_entry_name(node: GraphNode) -> bool:
     return False
 
 
+def _matches_async_entry_name(node: GraphNode) -> bool:
+    """Return True if *node*'s name matches an async entry-point pattern."""
+    for pat in _ASYNC_ENTRY_PATTERNS:
+        if pat.search(node.name):
+            return True
+    return False
+
+
 def detect_entry_points(store: GraphStore) -> list[GraphNode]:
     """Find functions that are entry points in the graph.
 
     An entry point is a Function/Test node that either:
     1. Has no incoming CALLS edges (true root), or
     2. Has a framework decorator (e.g. ``@app.get``), or
-    3. Matches a conventional name pattern (``main``, ``test_*``, etc.).
+    3. Matches a conventional name pattern (``main``, ``test_*``, etc.), or
+    4. Is an async function with a matching async entry pattern.
     """
     # Build a set of all qualified names that are CALLS targets.
     called_qnames = store.get_all_call_targets()
@@ -105,6 +124,10 @@ def detect_entry_points(store: GraphStore) -> list[GraphNode]:
         if _matches_entry_name(node):
             is_entry = True
 
+        # Async function with async entry pattern match.
+        if node.extra.get("is_async") and _matches_async_entry_name(node):
+            is_entry = True
+
         if is_entry and node.qualified_name not in seen_qn:
             entry_points.append(node)
             seen_qn.add(node.qualified_name)
@@ -130,6 +153,8 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
       - file_count: number of distinct files touched
       - files: list of distinct file paths
       - criticality: computed criticality score (0.0-1.0)
+      - is_async: whether the entry point is async
+      - async_call_count: number of async calls (await) in the flow
     """
     entry_points = detect_entry_points(store)
     flows: list[dict] = []
@@ -139,6 +164,7 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
         path_qnames: list[str] = []
         visited: set[str] = set()
         queue: deque[tuple[str, int]] = deque()
+        async_call_count = 0
 
         # Seed with the entry point itself.
         queue.append((ep.qualified_name, 0))
@@ -158,8 +184,10 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
             # Follow forward CALLS edges.
             edges = store.get_edges_by_source(current_qn)
             for edge in edges:
-                if edge.kind != "CALLS":
+                if edge.kind not in ("CALLS", "AWAITS"):
                     continue
+                if edge.kind == "AWAITS":
+                    async_call_count += 1
                 target_qn = edge.target_qualified
                 if target_qn in visited:
                     continue
@@ -192,6 +220,8 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
             "file_count": len(files),
             "files": files,
             "criticality": 0.0,
+            "is_async": ep.extra.get("is_async", False),
+            "async_call_count": async_call_count,
         }
         flow["criticality"] = compute_criticality(flow, store)
         flows.append(flow)
diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py
index 2e19cea..db0fdae 100644
--- a/code_review_graph/parser.py
+++ b/code_review_graph/parser.py
@@ -851,6 +851,14 @@ def _resolve_call_targets(
     _MAX_AST_DEPTH = 180  # Guard against pathologically nested source files
     _MAX_TEST_DESCRIPTION_LEN = 200  # Cap test description length in node names
 
+    def _is_async_function(self, node, language: str) -> bool:
+        """Return True if the function/method definition node is async."""
+        if language in ("python", "javascript", "typescript", "tsx"):
+            for child in node.children:
+                if child.type == "async":
+                    return True
+        return False
+
     def _get_test_description(self, call_node, source: bytes) -> Optional[str]:
         """Extract the first string argument from a test runner call node."""
         for child in call_node.children:
@@ -1546,6 +1554,7 @@ def _extract_functions(
         qualified = self._qualify(name, file_path, enclosing_class)
         params = self._get_params(child, language, source)
         ret_type = self._get_return_type(child, language, source)
+        is_async = self._is_async_function(child, language)
 
         node = NodeInfo(
             kind=kind,
@@ -1558,6 +1567,7 @@ def _extract_functions(
             params=params,
             return_type=ret_type,
             is_test=is_test,
+            is_async=is_async,
         )
         nodes.append(node)
 