Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions code_review_graph/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@
re.compile(r"^handle_"),
]

# Name patterns for async entry points.
_ASYNC_ENTRY_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"^async_main$"),
re.compile(r"^run$"),
re.compile(r"^start$"),
re.compile(r"^serve$"),
re.compile(r"^listen$"),
re.compile(r"^bootstrap$"),
]


# ---------------------------------------------------------------------------
# Entry-point detection
Expand Down Expand Up @@ -73,13 +83,22 @@ def _matches_entry_name(node: GraphNode) -> bool:
return False


def _matches_async_entry_name(node: GraphNode) -> bool:
"""Return True if *node*'s name matches an async entry-point pattern."""
for pat in _ASYNC_ENTRY_PATTERNS:
if pat.search(node.name):
return True
return False


def detect_entry_points(store: GraphStore) -> list[GraphNode]:
"""Find functions that are entry points in the graph.

An entry point is a Function/Test node that either:
1. Has no incoming CALLS edges (true root), or
2. Has a framework decorator (e.g. ``@app.get``), or
3. Matches a conventional name pattern (``main``, ``test_*``, etc.).
3. Matches a conventional name pattern (``main``, ``test_*``, etc.), or
4. Is an async function with a matching async entry pattern.
"""
# Build a set of all qualified names that are CALLS targets.
called_qnames = store.get_all_call_targets()
Expand All @@ -105,6 +124,10 @@ def detect_entry_points(store: GraphStore) -> list[GraphNode]:
if _matches_entry_name(node):
is_entry = True

# Async function with async entry pattern match.
if node.extra.get("is_async") and _matches_async_entry_name(node):
is_entry = True

if is_entry and node.qualified_name not in seen_qn:
entry_points.append(node)
seen_qn.add(node.qualified_name)
Expand All @@ -130,6 +153,8 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
- file_count: number of distinct files touched
- files: list of distinct file paths
- criticality: computed criticality score (0.0-1.0)
- is_async: whether the entry point is async
- async_call_count: number of async calls (await) in the flow
"""
entry_points = detect_entry_points(store)
flows: list[dict] = []
Expand All @@ -139,6 +164,7 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
path_qnames: list[str] = []
visited: set[str] = set()
queue: deque[tuple[str, int]] = deque()
async_call_count = 0

# Seed with the entry point itself.
queue.append((ep.qualified_name, 0))
Expand All @@ -158,8 +184,10 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
# Follow forward CALLS edges.
edges = store.get_edges_by_source(current_qn)
for edge in edges:
if edge.kind != "CALLS":
if edge.kind not in ("CALLS", "AWAITS"):
continue
if edge.kind == "AWAITS":
async_call_count += 1
target_qn = edge.target_qualified
if target_qn in visited:
continue
Expand Down Expand Up @@ -192,6 +220,8 @@ def trace_flows(store: GraphStore, max_depth: int = 15) -> list[dict]:
"file_count": len(files),
"files": files,
"criticality": 0.0,
"is_async": ep.extra.get("is_async", False),
"async_call_count": async_call_count,
}
flow["criticality"] = compute_criticality(flow, store)
flows.append(flow)
Expand Down
10 changes: 10 additions & 0 deletions code_review_graph/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,14 @@ def _resolve_call_targets(
_MAX_AST_DEPTH = 180 # Guard against pathologically nested source files
_MAX_TEST_DESCRIPTION_LEN = 200 # Cap test description length in node names

def _is_async_function(self, node, language: str) -> bool:
"""Return True if the function/method definition node is async."""
if language in ("python", "javascript", "typescript", "tsx"):
for child in node.children:
if child.type == "async":
return True
return False

def _get_test_description(self, call_node, source: bytes) -> Optional[str]:
"""Extract the first string argument from a test runner call node."""
for child in call_node.children:
Expand Down Expand Up @@ -1546,6 +1554,7 @@ def _extract_functions(
qualified = self._qualify(name, file_path, enclosing_class)
params = self._get_params(child, language, source)
ret_type = self._get_return_type(child, language, source)
is_async = self._is_async_function(child, language)

node = NodeInfo(
kind=kind,
Expand All @@ -1558,6 +1567,7 @@ def _extract_functions(
params=params,
return_type=ret_type,
is_test=is_test,
is_async=is_async,
)
nodes.append(node)

Expand Down
Loading