Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(filename)s - %(asctime)s - %(levelname)s - %(message)s')

def _define_ontology() -> Ontology:
def define_ontology() -> Ontology:
# Build ontology:
ontology = Ontology()

Expand Down Expand Up @@ -233,14 +233,14 @@ def _define_ontology() -> Ontology:
return ontology

# Global ontology
ontology = _define_ontology()
ontology = define_ontology()

def _create_kg_agent(repo_name: str):
model_name = os.getenv('MODEL_NAME', 'gemini/gemini-flash-lite-latest')

model = LiteModel(model_name)

#ontology = _define_ontology()
#ontology = define_ontology()
code_graph_kg = KnowledgeGraph(
name=repo_name,
ontology=ontology,
Expand Down
33 changes: 33 additions & 0 deletions api/mcp/code_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""MCP-side GraphRAG prompt overrides (T10).

Today this module is a thin re-export of ``api.prompts``. The point is the
**seam**: when the MCP ``ask`` tool needs prompt framing tuned for
"the user is an AI agent inspecting a codebase" instead of
"a human chatting about their repo", divergence happens *here* without
touching the existing FastAPI ``/api/chat`` prompts.

Until that day, every prompt below is identical to its ``api.prompts``
counterpart — verified by ``tests/mcp/test_code_prompts.py``.
"""

from __future__ import annotations

from api.prompts import (
CYPHER_GEN_PROMPT,
CYPHER_GEN_SYSTEM,
GRAPH_QA_PROMPT,
GRAPH_QA_SYSTEM,
)


__all__ = [
"CYPHER_GEN_SYSTEM",
"CYPHER_GEN_PROMPT",
"GRAPH_QA_SYSTEM",
"GRAPH_QA_PROMPT",
]


# TODO(MCP): start diverging here when agent-vs-human framing matters.
# Keep `api/prompts.py` as the canonical reference for the FastAPI
# chat endpoint and override the MCP-facing variants in this module.
85 changes: 85 additions & 0 deletions api/mcp/graphrag_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""GraphRAG init for the MCP ``ask`` tool (T9, refined by T11).

The MCP ``ask`` tool needs one ``KnowledgeGraph`` instance per
``(project, branch)`` to drive GraphRAG's NL→Cypher→QA round-trip. Building
one is non-trivial — ontology, model, prompts, FalkorDB connection — and
the existing ``api/llm.py`` builder bakes in a single repo name at module
import.

This module exposes:

* :func:`get_or_create_kg` — process-wide cache keyed by
``(project, branch)``. Cheap to call; one instance reused across many
``ask`` invocations.
* :func:`reset_cache` — used in tests to drop the cache between runs.

The ontology is intentionally reused from ``api.llm.define_ontology`` — it's
200+ lines of hand-tuned descriptions of File/Class/Function entities that
the LLM relies on to generate good Cypher. Replacing it with
``Ontology.from_kg_graph()`` (auto-extraction) is a regression.
"""

from __future__ import annotations

import os
from typing import Tuple

from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
from graphrag_sdk.models.litellm import LiteModel

from api.graph import compose_graph_name
from api.llm import define_ontology
from api.mcp.code_prompts import (
CYPHER_GEN_PROMPT,
CYPHER_GEN_SYSTEM,
GRAPH_QA_PROMPT,
GRAPH_QA_SYSTEM,
)


_CACHE: dict[Tuple[str, str], KnowledgeGraph] = {}


def _make_model() -> LiteModel:
"""Build the LiteModel from ``$MODEL_NAME`` (same default as api/llm.py)."""
model_name = os.getenv("MODEL_NAME", "gemini/gemini-flash-lite-latest")
return LiteModel(model_name)


def get_or_create_kg(project_name: str, branch: str = "_default") -> KnowledgeGraph:
"""Return a cached :class:`KnowledgeGraph` for ``(project, branch)``.

Two calls with the same ``(project, branch)`` are guaranteed to return
the **same** instance (identity preserved) so callers don't pay the
construction cost on every ``ask``.

The underlying graph name uses the T17 convention
``code:{project}:{branch}`` so per-branch indexing works end-to-end.
"""
key = (project_name, branch)
cached = _CACHE.get(key)
if cached is not None:
return cached

graph_name = compose_graph_name(project_name, branch)
model = _make_model()
kg = KnowledgeGraph(
name=graph_name,
ontology=define_ontology(),
model_config=KnowledgeGraphModelConfig.with_model(model),
host=os.getenv("FALKORDB_HOST", "localhost"),
port=int(os.getenv("FALKORDB_PORT", 6379)),
username=os.getenv("FALKORDB_USERNAME", None),
password=os.getenv("FALKORDB_PASSWORD", None),
cypher_system_instruction=CYPHER_GEN_SYSTEM,
qa_system_instruction=GRAPH_QA_SYSTEM,
cypher_gen_prompt=CYPHER_GEN_PROMPT,
qa_prompt=GRAPH_QA_PROMPT,
)
_CACHE[key] = kg
return kg


def reset_cache() -> None:
"""Drop the per-process KG cache. Tests only."""
_CACHE.clear()
2 changes: 1 addition & 1 deletion api/mcp/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
``api.mcp.server``. Import this package to register all tools.
"""

from . import structural # noqa: F401 (registers tools on import)
from . import ask, structural # noqa: F401 (registers tools on import)
93 changes: 93 additions & 0 deletions api/mcp/tools/ask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""MCP ``ask`` tool — NL → Cypher → QA via GraphRAG (T11).

This is the strategic differentiator vs purely structural code-graph MCP
servers: the agent asks a natural-language question, and we return the
LLM's answer plus the actual Cypher that was executed (for transparency
and learning).

Two LLM round-trips bracket one FalkorDB query:

1. **LLM #1 (cypher gen):** question + ontology → Cypher
2. **FalkorDB:** execute Cypher → rows of nodes
3. **LLM #2 (QA synthesis):** question + rows → natural-language answer

The graph itself never goes to the LLM — only the schema and per-query
results — which is what makes this scale to huge codebases.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Any, Optional

from ..graphrag_init import get_or_create_kg
from ..server import app


logger = logging.getLogger(__name__)


def _normalize_response(raw: Any) -> dict[str, Any]:
"""Coerce graphrag-sdk's chat response into the MCP payload shape.

graphrag-sdk shapes its return as a ``dict`` with at least a
``response`` (the natural-language answer) and, depending on the
SDK version, ``cypher`` / ``context``. We surface ``cypher_query``
and ``context_nodes`` regardless — the design doc requires the
Cypher to be visible so agents can debug, learn, and decide whether
the query was sensible.
"""
if not isinstance(raw, dict):
return {"answer": str(raw), "cypher_query": None, "context_nodes": []}

answer = raw.get("response") or raw.get("answer") or ""
cypher = raw.get("cypher_query") or raw.get("cypher") or raw.get("query")
ctx = (
raw.get("context_nodes")
or raw.get("context")
or raw.get("results")
or []
)
return {
"answer": answer,
"cypher_query": cypher,
"context_nodes": ctx,
}


@app.tool(
name="ask",
description=(
"Ask a natural-language question about the indexed codebase. "
"Powered by GraphRAG: the question is translated to Cypher, "
"executed against the FalkorDB code graph, and the rows are "
"summarised in English. The executed Cypher is returned in "
"`cypher_query` so the agent can verify the answer and learn the "
"schema."
),
)
async def ask(
question: str,
project: str,
branch: Optional[str] = None,
) -> dict[str, Any]:
kg = get_or_create_kg(project, branch or "_default")
loop = asyncio.get_running_loop()

def _ask_sync() -> Any:
chat = kg.chat_session()
return chat.send_message(question)

try:
raw = await loop.run_in_executor(None, _ask_sync)
except Exception as exc: # surface as a structured failure, not a crash
logger.exception("ask failed for project=%s branch=%s", project, branch)
return {
"answer": "",
"cypher_query": None,
"context_nodes": [],
"error": str(exc),
}

return _normalize_response(raw)
131 changes: 131 additions & 0 deletions tests/mcp/test_ask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""T11 — MCP ``ask`` tool tests (mocked LLM)."""

from __future__ import annotations

import json
from unittest.mock import MagicMock, patch

import pytest


pytestmark = pytest.mark.anyio


@pytest.fixture
def anyio_backend() -> str:
return "asyncio"


@pytest.fixture(autouse=True)
def _reset_kg_cache():
from api.mcp.graphrag_init import reset_cache

reset_cache()
yield
reset_cache()


async def test_ask_registered():
from api.mcp.server import app

names = {t.name for t in await app.list_tools()}
assert "ask" in names


async def test_ask_returns_normalised_payload():
"""Mock the entire KG; ensure the ask tool shapes its response
correctly: {answer, cypher_query, context_nodes}.
"""
from api.mcp.tools.ask import ask

fake_chat = MagicMock()
fake_chat.send_message.return_value = {
"response": "service is called by entrypoint.",
"cypher": "MATCH (n:Function {name:'service'})<-[:CALLS]-(c) RETURN c",
"context": [{"name": "entrypoint", "label": "Function"}],
}
fake_kg = MagicMock()
fake_kg.chat_session.return_value = fake_chat

with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg):
result = await ask(question="who calls service?", project="p", branch="b")

assert result["answer"] == "service is called by entrypoint."
assert "MATCH" in (result["cypher_query"] or "")
assert result["context_nodes"] == [{"name": "entrypoint", "label": "Function"}]
assert "error" not in result

fake_kg.chat_session.assert_called_once()
fake_chat.send_message.assert_called_once_with("who calls service?")


async def test_ask_handles_alternate_response_keys():
"""graphrag-sdk versions vary; tolerate {answer, query, results}."""
from api.mcp.tools.ask import ask

fake_chat = MagicMock()
fake_chat.send_message.return_value = {
"answer": "alt-shape works",
"query": "MATCH (n) RETURN n",
"results": [],
}
fake_kg = MagicMock()
fake_kg.chat_session.return_value = fake_chat

with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg):
result = await ask(question="anything", project="p")

assert result["answer"] == "alt-shape works"
assert result["cypher_query"] == "MATCH (n) RETURN n"
assert result["context_nodes"] == []


async def test_ask_handles_string_response():
from api.mcp.tools.ask import ask

fake_chat = MagicMock()
fake_chat.send_message.return_value = "plain string answer"
fake_kg = MagicMock()
fake_kg.chat_session.return_value = fake_chat

with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg):
result = await ask(question="anything", project="p")

assert result["answer"] == "plain string answer"
assert result["cypher_query"] is None
assert result["context_nodes"] == []


async def test_ask_surfaces_errors_as_payload_not_raise():
"""Tool crashes must return a structured error so the agent doesn't
see a transport exception."""
from api.mcp.tools.ask import ask

fake_chat = MagicMock()
fake_chat.send_message.side_effect = RuntimeError("model unavailable")
fake_kg = MagicMock()
fake_kg.chat_session.return_value = fake_chat

with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg):
result = await ask(question="anything", project="p")

assert result["answer"] == ""
assert result["error"] == "model unavailable"


async def test_ask_response_is_json_serialisable():
from api.mcp.tools.ask import ask

fake_chat = MagicMock()
fake_chat.send_message.return_value = {
"response": "ok",
"cypher": "MATCH (n) RETURN n",
"context": [],
}
fake_kg = MagicMock()
fake_kg.chat_session.return_value = fake_chat

with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg):
result = await ask(question="q", project="p")

json.dumps(result) # must not raise
Loading