From ca59dd564497a835bc77ab3f2eef4a7b741ebbff Mon Sep 17 00:00:00 2001 From: 2002yy <15135142681@163.com> Date: Fri, 5 Jun 2026 17:36:27 +0800 Subject: [PATCH] Add RAG query debug details --- README.md | 6 +- docs/INTERVIEW_NOTES.md | 2 +- docs/RAG.md | 14 ++++- docs/TESTING.md | 8 +-- src/api.py | 38 +++++++++-- src/rag/__init__.py | 4 ++ src/rag/service.py | 135 +++++++++++++++++++++++++++++++++++++++- tests/test_api.py | 29 +++++++++ tests/test_rag.py | 54 +++++++++++++++- 9 files changed, 272 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index ae689d9..db30e08 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
A local AI learning assistant with long-term memory, role-based group chat, @@ -31,7 +31,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **SSRF protection** for article fetching, **detect-secrets** in CI - **Batched session logging** and multi-layer caching for performance - **Performance budget**: mode-based `max_tokens` bounds on the main chat, WeChat, and news LLM paths -- **262 pytest tests**, Ruff clean, GitHub Actions CI workflow +- **265 pytest tests**, Ruff clean, GitHub Actions CI workflow For a detailed breakdown of the stack and engineering highlights, see [Technical Stack & Engineering Highlights](docs/TECH_STACK.md). @@ -255,7 +255,7 @@ pip-compile requirements-dev.in # 重新锁定开发依赖 ## 测试 ```bash -pytest tests/ -v # current local baseline: 262 passed +pytest tests/ -v # current local baseline: 265 passed pytest tests/ --cov=src # 覆盖率 ruff check src/ tests/ # linting mypy --explicit-package-bases src/ # CI soft check; may report type debt diff --git a/docs/INTERVIEW_NOTES.md b/docs/INTERVIEW_NOTES.md index 81858d3..30b9253 100644 --- a/docs/INTERVIEW_NOTES.md +++ b/docs/INTERVIEW_NOTES.md @@ -10,7 +10,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点在多 Provider 模 2. **长期记忆写入安全** — safe writer + preview/confirm 机制,防止不可逆的记忆污染 3. **联网搜索来源追溯** — Feed registry / RSS 多源聚合 → URL safety matrix → 文章正文三层提取 → LLM digest → pipeline trace 全过程来源可回溯 4. **Streamlit 重渲染性能优化** — 多层缓存策略、按模式批量落盘、主链路 token 预算控制 -5. **CI / Ruff / detect-secrets 工程检查** — 262 pytest tests、Ruff clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 +5. **CI / Ruff / detect-secrets 工程检查** — 265 pytest tests、Ruff clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 ## 可讲亮点 diff --git a/docs/RAG.md b/docs/RAG.md index 98dec43..2cbe085 100644 --- a/docs/RAG.md +++ b/docs/RAG.md @@ -116,15 +116,23 @@ Regression coverage lives in `tests/test_rag.py` and verifies: - Empty-result and miss accounting - Unknown retrieval mode rejection +P4-B adds API/query diagnostics: + +- Retrieval mode, `top_k`, `min_score` and tokenized query terms +- Candidate count and returned result count +- Per-result rank, chunk id, source path, matched terms and score breakdown +- Optional one-query evaluation when `/rag/query` receives `expected_sources` + ## Next Steps ### P4: Retrieval Quality Loop Goal: prove retrieval quality before expanding the stack. -- Add a small gold fixture set with queries, expected sources and expected terms. -- Track `recall@k`, mean reciprocal rank, source hit rate and empty-result rate. -- Surface retrieval debug data in tests and API responses before adding more UI polish. +- [x] Add a small gold fixture set with queries, expected sources and expected terms. +- [x] Track `recall@k`, mean reciprocal rank, source hit rate and empty-result rate. +- [x] Surface retrieval debug data in tests and API responses before adding more UI polish. +- [ ] Add a Streamlit source/debug panel for inspecting score breakdowns. - Keep the first evaluation layer LLM-free so CI can catch retrieval regressions deterministically. ### P5: Real Embedding Backend diff --git a/docs/TESTING.md b/docs/TESTING.md index 9909fbf..ad0e41d 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -6,7 +6,7 @@ Current verified baseline: | Check | Status | Evidence | |---|---|---| -| pytest | Passed | `262 passed` locally on 2026-06-05 | +| pytest | Passed | `265 passed` locally on 2026-06-05 | | Ruff | Passed | `python -m ruff check .` clean locally on 2026-06-04 | | Package helper | Passed | `python tools/package_project_helper.py . NUL 0` locally on 2026-06-04 | | mypy | Soft check, not clean | `python -m mypy --explicit-package-bases src/` reported 18 errors locally on 2026-06-04 | @@ -24,9 +24,9 @@ Current verified baseline: | **News URL safety** | `test_url_normalizer.py`, `test_link_resolver.py` | 28 | | **News pipeline trace / audit** | `test_news_pipeline_trace.py`, `test_news_audit.py` | 5 | | **Feed registry / health** | `test_feed_registry.py`, `test_feed_diagnostics.py` | 9 | -| **RAG MVP** | `test_rag.py` | 20 | +| **RAG MVP** | `test_rag.py` | 22 | | **RAG evaluation** | `test_rag_eval.py` | 5 | -| **FastAPI RAG endpoints** | `test_api.py` | 5 | +| **FastAPI RAG endpoints** | `test_api.py` | 6 | | **Architecture flows** | `test_architecture_flows.py` | 12 | | **WeChat decoupling** | `test_wechat_decoupling.py` | 4 | | **Sidebar rerun** | `test_sidebar_global_rerun.py` | 12 | @@ -76,7 +76,7 @@ def test_flush_uses_safe_writer(): ## Running Tests ```bash -python -m pytest # current baseline: 262 passed +python -m pytest # current baseline: 265 passed pytest tests/ -v # Verbose pytest tests/ --cov=src # Coverage python -m ruff check . # Linting diff --git a/src/api.py b/src/api.py index 3dc2e30..7364729 100644 --- a/src/api.py +++ b/src/api.py @@ -6,8 +6,10 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field -from src.rag import build_rag_context, format_rag_sources, index_documents, query_documents -from src.rag.index import DEFAULT_RAG_INDEX_PATH +from src.rag import build_rag_context, format_rag_sources, index_documents +from src.rag.eval import RagEvalCase, evaluate_case +from src.rag.index import DEFAULT_RAG_INDEX_PATH, load_rag_index +from src.rag.service import build_rag_debug, search_documents class HealthResponse(BaseModel): @@ -36,6 +38,8 @@ class RagQueryRequest(BaseModel): min_score: float = Field(default=0.01, ge=0) retrieval_mode: str = Field(default="hybrid") context_max_chars: int = Field(default=3000, gt=0, le=20_000) + expected_sources: list[str] = Field(default_factory=list) + expected_terms: list[str] = Field(default_factory=list) class RagQueryResponse(BaseModel): @@ -45,6 +49,8 @@ class RagQueryResponse(BaseModel): context: str sources: str results: list[dict[str, Any]] + debug: dict[str, Any] + evaluation: dict[str, Any] | None = None app = FastAPI(title="Study Agent API", version="0.1.0") @@ -88,13 +94,35 @@ def build_rag_index_endpoint(request: RagIndexRequest) -> RagIndexResponse: @app.post("/rag/query", response_model=RagQueryResponse) def query_rag_endpoint(request: RagQueryRequest) -> RagQueryResponse: try: - results = query_documents( + index = load_rag_index(_index_path(request.index_path)) + results = search_documents( + index, request.query, - index_path=_index_path(request.index_path), top_k=request.top_k, min_score=request.min_score, retrieval_mode=request.retrieval_mode, ) + debug = build_rag_debug( + index, + request.query, + results, + retrieval_mode=request.retrieval_mode, + top_k=request.top_k, + min_score=request.min_score, + ) + evaluation = None + if request.expected_sources: + evaluation = evaluate_case( + index, + RagEvalCase( + query=request.query, + expected_sources=tuple(request.expected_sources), + expected_terms=tuple(request.expected_terms), + top_k=request.top_k, + retrieval_mode=request.retrieval_mode, + ), + min_score=request.min_score, + ).to_dict() except FileNotFoundError as exc: raise HTTPException(status_code=404, detail="RAG index not found") from exc except ValueError as exc: @@ -107,6 +135,8 @@ def query_rag_endpoint(request: RagQueryRequest) -> RagQueryResponse: context=build_rag_context(results, max_chars=request.context_max_chars), sources=format_rag_sources(results), results=[result.to_dict() for result in results], + debug=debug, + evaluation=evaluation, ) diff --git a/src/rag/__init__.py b/src/rag/__init__.py index 60000ba..aa4f200 100644 --- a/src/rag/__init__.py +++ b/src/rag/__init__.py @@ -15,10 +15,12 @@ load_eval_cases, ) from src.rag.service import ( + build_rag_debug, build_rag_context, format_rag_sources, index_documents, query_documents, + search_documents, ) from src.rag.vector import ( cosine_similarity, @@ -28,6 +30,7 @@ ) __all__ = [ + "build_rag_debug", "build_rag_context", "build_rag_index", "cosine_similarity", @@ -46,4 +49,5 @@ "search_rag_index_hybrid", "search_rag_index", "search_rag_index_vector", + "search_documents", ] diff --git a/src/rag/service.py b/src/rag/service.py index b42ec51..741e23d 100644 --- a/src/rag/service.py +++ b/src/rag/service.py @@ -2,16 +2,28 @@ from collections.abc import Sequence from pathlib import Path +from typing import Any from src.rag.index import ( DEFAULT_RAG_INDEX_PATH, + _document_frequency, + _score_chunk, + _tokenize, build_rag_index, load_rag_index, save_rag_index, search_rag_index, ) from src.rag.schema import RagIndex, RagSearchResult -from src.rag.vector import search_rag_index_hybrid, search_rag_index_vector +from src.rag.vector import ( + cosine_similarity, + embed_text, + search_rag_index_hybrid, + search_rag_index_vector, +) + +RETRIEVAL_MODES = {"lexical", "vector", "hybrid"} +HYBRID_LEXICAL_WEIGHT = 0.7 def index_documents( @@ -41,15 +53,134 @@ def query_documents( ) -> list[RagSearchResult]: """Search the persisted local RAG index.""" index = load_rag_index(index_path) + return search_documents( + index, + query, + top_k=top_k, + min_score=min_score, + retrieval_mode=retrieval_mode, + ) + + +def search_documents( + index: RagIndex, + query: str, + *, + top_k: int = 5, + min_score: float = 0.01, + retrieval_mode: str = "lexical", +) -> list[RagSearchResult]: + """Search an in-memory RAG index.""" if retrieval_mode == "lexical": return search_rag_index(index, query, top_k=top_k, min_score=min_score) if retrieval_mode == "vector": return search_rag_index_vector(index, query, top_k=top_k, min_score=min_score) if retrieval_mode == "hybrid": - return search_rag_index_hybrid(index, query, top_k=top_k, min_score=min_score) + return search_rag_index_hybrid( + index, + query, + top_k=top_k, + min_score=min_score, + lexical_weight=HYBRID_LEXICAL_WEIGHT, + ) raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}") +def _lexical_scores(index: RagIndex, query: str) -> dict[str, float]: + if not index.chunks: + return {} + df = _document_frequency(index.chunks) + return { + chunk.chunk_id: round(_score_chunk(query, chunk, df, len(index.chunks))[0], 6) + for chunk in index.chunks + } + + +def _vector_scores(index: RagIndex, query: str) -> dict[str, float]: + query_vector = embed_text(query) + if not any(query_vector): + return {chunk.chunk_id: 0.0 for chunk in index.chunks} + return { + chunk.chunk_id: round(cosine_similarity(query_vector, embed_text(chunk.text)), 6) + for chunk in index.chunks + } + + +def _score_breakdown( + result: RagSearchResult, + *, + retrieval_mode: str, + lexical_scores: dict[str, float], + vector_scores: dict[str, float], + max_lexical_score: float, +) -> dict[str, float]: + chunk_id = result.chunk.chunk_id + lexical_score = lexical_scores.get(chunk_id, 0.0) + lexical_normalized = lexical_score / max_lexical_score if max_lexical_score > 0 else 0.0 + vector_score = vector_scores.get(chunk_id, 0.0) + + if retrieval_mode == "lexical": + return {"lexical_score": round(lexical_score, 6)} + if retrieval_mode == "vector": + return {"vector_score": round(vector_score, 6)} + if retrieval_mode == "hybrid": + return { + "lexical_weight": HYBRID_LEXICAL_WEIGHT, + "lexical_score": round(lexical_score, 6), + "lexical_normalized": round(lexical_normalized, 6), + "vector_score": round(vector_score, 6), + "combined_score": round(result.score, 6), + } + raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}") + + +def build_rag_debug( + index: RagIndex, + query: str, + results: list[RagSearchResult], + *, + retrieval_mode: str, + top_k: int, + min_score: float, +) -> dict[str, Any]: + """Build explainable retrieval diagnostics for API and evaluation views.""" + if retrieval_mode not in RETRIEVAL_MODES: + raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}") + + lexical_scores = _lexical_scores(index, query) + vector_scores = _vector_scores(index, query) + max_lexical_score = max(lexical_scores.values(), default=0.0) + query_terms = tuple(sorted(set(_tokenize(query)))) + + return { + "retrieval_mode": retrieval_mode, + "top_k": top_k, + "min_score": min_score, + "candidate_count": len(index.chunks), + "returned_count": len(results), + "query_terms": list(query_terms), + "empty_query": not query_terms, + "results": [ + { + "rank": rank, + "chunk_id": result.chunk.chunk_id, + "source_path": result.chunk.source_path, + "title": result.chunk.title, + "score": result.score, + "matched_terms": list(result.matched_terms), + "score_breakdown": _score_breakdown( + result, + retrieval_mode=retrieval_mode, + lexical_scores=lexical_scores, + vector_scores=vector_scores, + max_lexical_score=max_lexical_score, + ), + } + for rank, result in enumerate(results, start=1) + ], + } + + def build_rag_context( results: list[RagSearchResult], *, diff --git a/tests/test_api.py b/tests/test_api.py index 64adb91..ba72989 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -49,6 +49,9 @@ def test_rag_index_and_query_endpoints(tmp_path): assert query_data["result_count"] == 1 assert "[1] notes" in query_data["context"] assert query_data["results"][0]["chunk"]["source_path"] == str(document) + assert query_data["debug"]["retrieval_mode"] == "hybrid" + assert query_data["debug"]["candidate_count"] == 1 + assert query_data["debug"]["results"][0]["score_breakdown"]["combined_score"] > 0 def test_rag_alias_queries_existing_index(tmp_path): @@ -83,6 +86,32 @@ def test_rag_query_endpoint_rejects_unknown_mode(tmp_path): assert "Unsupported RAG retrieval mode" in response.json()["detail"] +def test_rag_query_endpoint_returns_optional_evaluation(tmp_path): + client = TestClient(app) + document = tmp_path / "notes.md" + index_path = tmp_path / "rag_index.json" + document.write_text("Evaluation expects this cited local source.", encoding="utf-8") + client.post("/rag/index", json={"paths": [str(document)], "index_path": str(index_path)}) + + response = client.post( + "/rag/query", + json={ + "query": "cited local source", + "index_path": str(index_path), + "retrieval_mode": "hybrid", + "expected_sources": [document.name], + "expected_terms": ["cited", "source"], + "top_k": 2, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["evaluation"]["hit"] is True + assert data["evaluation"]["recall_at_k"] == 1.0 + assert data["debug"]["returned_count"] >= 1 + + def test_rag_index_endpoint_reports_missing_files(tmp_path): client = TestClient(app) diff --git a/tests/test_rag.py b/tests/test_rag.py index 5ee5a19..7d1a292 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -6,7 +6,13 @@ import pytest -from src.rag import build_rag_context, format_rag_sources, index_documents, query_documents +from src.rag import ( + build_rag_context, + build_rag_debug, + format_rag_sources, + index_documents, + query_documents, +) from src.rag.chunker import chunk_document from src.rag.index import build_rag_index, load_rag_index, search_rag_index from src.rag.loader import load_document @@ -207,6 +213,52 @@ def test_query_documents_rejects_unknown_retrieval_mode(tmp_path): query_documents("retrieval", index_path=index_path, retrieval_mode="semantic") +def test_build_rag_debug_explains_hybrid_scores(tmp_path): + python_doc = tmp_path / "python.md" + cooking_doc = tmp_path / "cooking.md" + python_doc.write_text("HTTP requests sessions reuse connections.", encoding="utf-8") + cooking_doc.write_text("Tomato pasta sauce needs basil.", encoding="utf-8") + index = build_rag_index([python_doc, cooking_doc], max_chars=200, overlap_chars=0) + results = search_rag_index_hybrid(index, "requests connections", top_k=1) + + debug = build_rag_debug( + index, + "requests connections", + results, + retrieval_mode="hybrid", + top_k=1, + min_score=0.01, + ) + + assert debug["candidate_count"] == 2 + assert debug["returned_count"] == 1 + assert debug["query_terms"] == ["connections", "requests"] + breakdown = debug["results"][0]["score_breakdown"] + assert breakdown["lexical_weight"] == 0.7 + assert breakdown["combined_score"] == results[0].score + assert breakdown["lexical_score"] > 0 + assert breakdown["vector_score"] > 0 + + +def test_build_rag_debug_marks_empty_queries(tmp_path): + path = tmp_path / "notes.md" + path.write_text("Local retrieval.", encoding="utf-8") + index = build_rag_index([path], max_chars=200, overlap_chars=0) + + debug = build_rag_debug( + index, + "", + [], + retrieval_mode="lexical", + top_k=3, + min_score=0.01, + ) + + assert debug["empty_query"] is True + assert debug["candidate_count"] == 1 + assert debug["returned_count"] == 0 + + def test_local_hash_embeddings_are_deterministic(): left = embed_text("requests session reuse") right = embed_text("requests session reuse")