Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions api/mcp/tools/structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,242 @@ def _payload(project) -> dict[str, Any]:
}

return await loop.run_in_executor(None, _do_index)


# ---------------------------------------------------------------------------
# T5 — get_callers / get_callees / get_dependencies
# ---------------------------------------------------------------------------


def _project_arg(project: str, branch: Optional[str]):
    """Build the :class:`AsyncGraphQuery` handle for one ``(project, branch)`` pair.

    The tool functions in this module call ``close()`` on the returned
    object once they are done with it.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle or a heavy import at
    # module load -- confirm before hoisting to the top of the file.
    from api.graph import AsyncGraphQuery

    query_handle = AsyncGraphQuery(project, branch=branch)
    return query_handle


def _node_summary(n: Any) -> dict[str, Any]:
    """Flatten a graph node into the record shape handed back to agents.

    ``n`` is either an object exposing ``properties`` / ``labels`` (and
    optionally ``id``) attributes, or a mapping shaped like
    ``{id, labels, properties: {...}}``. Either way the result is a flat
    dict with the keys ``id``, ``name``, ``label``, ``file`` and ``line``;
    anything missing on the input comes back as ``None``.

    ``label`` is the first label that is not ``Searchable`` -- that one is
    only the fulltext-index marker, not a meaningful kind such as File,
    Class or Function.
    """
    if hasattr(n, "properties"):
        # Attribute-style node object.
        attrs = dict(n.properties or {})
        all_labels = list(n.labels or [])
        ident = getattr(n, "id", None)
    else:
        # Mapping-style (already encoded) node.
        record = dict(n)
        attrs = dict(record.get("properties") or {})
        all_labels = list(record.get("labels") or [])
        ident = record.get("id")

    # Pick the first label that carries meaning for the caller.
    primary = None
    for candidate in all_labels:
        if candidate != "Searchable":
            primary = candidate
            break

    summary: dict[str, Any] = {"id": ident}
    summary["name"] = attrs.get("name")
    summary["label"] = primary
    summary["file"] = attrs.get("path")
    summary["line"] = attrs.get("src_start")
    return summary


def _coerce_node_id(symbol_id: Any) -> int:
    """Accept int or stringified int; raise ValueError otherwise.

    The MCP wire format is JSON; agents sometimes hand back the id as a
    string. Be permissive on input, strict on type after parsing.

    Args:
        symbol_id: An ``int``, or a ``str`` made of decimal digits with at
            most one leading ``-``.

    Returns:
        The id as an ``int``.

    Raises:
        ValueError: For ``bool``, for any other non-int/non-str value, and
            for strings that are not a plain (optionally negative) integer.
            The message always names ``symbol_id`` and echoes the rejected
            value.
    """
    if isinstance(symbol_id, bool):  # bool is an int subclass; reject loudly
        raise ValueError(f"symbol_id must be an integer, got bool: {symbol_id!r}")
    if isinstance(symbol_id, int):
        return symbol_id
    if isinstance(symbol_id, str):
        # Strip at most ONE sign. ``lstrip("-")`` would let "--5" through the
        # pre-check only for ``int()`` to reject it with its own message.
        unsigned = symbol_id[1:] if symbol_id.startswith("-") else symbol_id
        # ``isdecimal`` (not ``isdigit``): the latter also accepts characters
        # such as superscript digits, which ``int()`` cannot parse.
        if unsigned.isdecimal():
            try:
                return int(symbol_id)
            except ValueError:
                # e.g. the interpreter's int/str digit-count limit; fall
                # through so the caller still sees the uniform message.
                pass
    raise ValueError(f"symbol_id must be an integer id, got: {symbol_id!r}")


async def _neighbors_payload(
    project: str,
    branch: Optional[str],
    symbol_id: Any,
    rel: str,
    direction: str,
    limit: int,
) -> list[dict[str, Any]]:
    """Shared implementation for caller/callee/dependency tools.

    ``direction`` is ``IN`` (incoming edges, e.g. callers) or ``OUT``
    (outgoing edges, e.g. callees). When ``IN`` we run the inverse Cypher
    ``(neighbor)-[:rel]->(target)``; ``AsyncGraphQuery.get_neighbors`` only
    walks outgoing edges, so we inline the Cypher here for symmetry.

    Args:
        project: Project name handed to ``_project_arg``.
        branch: Optional branch handed to ``_project_arg``.
        symbol_id: Node id (int or stringified int) of the anchor node.
        rel: Relationship type to follow. Must be an identifier
            (``[A-Za-z_][A-Za-z0-9_]*``) or several joined with ``|``.
        direction: ``"IN"`` or ``"OUT"``.
        limit: Maximum number of rows requested from the graph.

    Returns:
        One flat node record per neighbor (see ``_node_summary``), each
        extended with ``relation`` (the edge type) and ``direction``.

    Raises:
        ValueError: If ``rel`` is not a valid relationship type name, if
            ``direction`` is neither ``IN`` nor ``OUT``, or if ``symbol_id``
            is not an integer id.
    """
    import re

    # Relationship types cannot be supplied as Cypher query parameters, so
    # ``rel`` is spliced into the query text below. It can originate from
    # agent input (the ``rels`` argument of ``get_dependencies``), so only
    # plain identifiers -- optionally ``|``-joined -- are let through;
    # anything else could rewrite the query.
    rel_ok = isinstance(rel, str) and re.fullmatch(
        r"[A-Za-z_][A-Za-z0-9_]*(?:\|[A-Za-z_][A-Za-z0-9_]*)*", rel
    )
    if not rel_ok:
        raise ValueError(f"rel must be a relationship type name, got: {rel!r}")

    # Build the query before opening a connection so that bad arguments
    # never cost a round-trip.
    if direction == "OUT":
        q = (
            f"MATCH (n)-[e:{rel}]->(dest) "
            f"WHERE ID(n) = $sid "
            f"RETURN dest, type(e) AS rel "
            f"LIMIT $limit"
        )
    elif direction == "IN":
        q = (
            f"MATCH (src)-[e:{rel}]->(n) "
            f"WHERE ID(n) = $sid "
            f"RETURN src AS dest, type(e) AS rel "
            f"LIMIT $limit"
        )
    else:
        raise ValueError(f"direction must be IN or OUT, got: {direction!r}")

    node_id = _coerce_node_id(symbol_id)
    g = _project_arg(project, branch)
    try:
        res = await g._query(q, {"sid": node_id, "limit": int(limit)})
        out: list[dict[str, Any]] = []
        for row in res.result_set:
            # Each row is (neighbor node, edge type) per the RETURN clause.
            entry = _node_summary(row[0])
            entry["relation"] = row[1]
            entry["direction"] = direction
            out.append(entry)
        return out
    finally:
        # Always release the graph handle, including on query failure.
        await g.close()


@app.tool(
    name="get_callers",
    description=(
        "Return functions that call the given symbol (incoming CALLS edges). "
        "`symbol_id` is the integer node id returned by `search_code` or "
        "other tools."
    ),
)
async def get_callers(
    symbol_id: Any,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Callers are the source nodes of CALLS edges that point *at* the
    # symbol, hence the incoming ("IN") direction. All validation and the
    # graph round-trip live in the shared helper.
    callers = await _neighbors_payload(
        project=project,
        branch=branch,
        symbol_id=symbol_id,
        rel="CALLS",
        direction="IN",
        limit=limit,
    )
    return callers


@app.tool(
    name="get_callees",
    description=(
        "Return functions that the given symbol calls (outgoing CALLS edges)."
    ),
)
async def get_callees(
    symbol_id: Any,
    project: str,
    branch: Optional[str] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # Callees are the destination nodes of CALLS edges that leave the
    # symbol, hence the outgoing ("OUT") direction. All validation and the
    # graph round-trip live in the shared helper.
    callees = await _neighbors_payload(
        project=project,
        branch=branch,
        symbol_id=symbol_id,
        rel="CALLS",
        direction="OUT",
        limit=limit,
    )
    return callees


@app.tool(
    name="get_dependencies",
    description=(
        "Return outgoing neighbors of the given symbol across any of the "
        "specified relation types (default: IMPORTS, CALLS, DEFINES). "
        "Useful for 'what does this depend on' queries."
    ),
)
async def get_dependencies(
    symbol_id: Any,
    project: str,
    branch: Optional[str] = None,
    rels: Optional[list[str]] = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    # ``None`` selects the default relation set; an explicit list (even an
    # empty one) is used as given.
    relation_types = ["IMPORTS", "CALLS", "DEFINES"] if rels is None else rels

    # Relations are queried one after another, in the order given, and their
    # rows concatenated. A row is kept only the first time its
    # (node id, relation) pair is seen, so the same node may still appear
    # once per distinct relation.
    emitted: set[Any] = set()
    merged: list[dict[str, Any]] = []
    for relation in relation_types:
        neighbors = await _neighbors_payload(
            project, branch, symbol_id, relation, "OUT", limit
        )
        for neighbor in neighbors:
            fingerprint = (neighbor.get("id"), neighbor.get("relation"))
            if fingerprint not in emitted:
                emitted.add(fingerprint)
                merged.append(neighbor)
                # Stop as soon as the overall cap is reached; remaining
                # relations are not queried.
                if len(merged) >= limit:
                    return merged
    return merged


# ---------------------------------------------------------------------------
# T7 — find_path
# ---------------------------------------------------------------------------


@app.tool(
    name="find_path",
    description=(
        "Return up to `max_paths` CALLS-path sequences from `source_id` to "
        "`dest_id`. Useful for 'how does A reach B' questions. Returns an "
        "empty list when no path exists."
    ),
)
async def find_path(
    source_id: Any,
    dest_id: Any,
    project: str,
    branch: Optional[str] = None,
    max_paths: int = 10,
) -> list[dict[str, Any]]:
    # Both ids go through the same permissive-int parsing as the other
    # tools; a malformed id raises ValueError before any connection opens.
    src = _coerce_node_id(source_id)
    dst = _coerce_node_id(dest_id)

    # A negative slice bound would drop paths from the *end* of the result
    # (``raw[:-1]`` is "all but the last") instead of returning none, which
    # contradicts "up to `max_paths`". Treat any negative cap as zero.
    if max_paths is not None and max_paths < 0:
        max_paths = 0

    g = _project_arg(project, branch)
    try:
        raw = await g.find_paths(src, dst)
    finally:
        await g.close()

    # ``AsyncGraphQuery.find_paths`` returns each path as an alternating
    # [node, edge, node, edge, ..., node] list; we strip edges and surface
    # only the node sequence -- that's what agents typically want.
    paths: list[dict[str, Any]] = []
    for entry in raw[:max_paths]:
        node_seq = [
            _node_summary(x)
            for x in entry
            # Edges in the alternating list carry a top-level ``relation``
            # key (from ``encode_edge``); nodes carry ``properties``.
            if isinstance(x, dict) and "properties" in x
        ]
        paths.append({"path": node_seq})
    return paths


# ---------------------------------------------------------------------------
# T8 — search_code
# ---------------------------------------------------------------------------


@app.tool(
    name="search_code",
    description=(
        "Prefix-search for symbols (functions, classes, files) whose name "
        "starts with `prefix`. Backed by FalkorDB's full-text index. The "
        "agent typically calls this first to discover symbol ids for the "
        "navigation tools (`get_callers`, `find_path`, ...)."
    ),
)
async def search_code(
    prefix: str,
    project: str,
    branch: Optional[str] = None,
    limit: int = 20,
) -> list[dict[str, Any]]:
    # A negative slice bound would trim matches from the *end* of the result
    # list (``raw[:-1]`` is "all but the last") instead of returning none.
    # Treat any negative limit as zero.
    if limit is not None and limit < 0:
        limit = 0

    g = _project_arg(project, branch)
    try:
        raw = await g.prefix_search(prefix)
    finally:
        # Release the graph handle whether or not the search succeeded.
        await g.close()

    # The cap is applied client-side to whatever ``prefix_search`` returned;
    # each surviving node is flattened to the standard record shape.
    return [_node_summary(node) for node in raw[:limit]]
Loading