Merged
6 changes: 3 additions & 3 deletions README.md
@@ -3,7 +3,7 @@
 <p>
 <a href="https://github.com/2002yy/study-agent/actions/workflows/ci.yml"><img src="https://github.com/2002yy/study-agent/actions/workflows/ci.yml/badge.svg" alt="CI"></a>
 <img src="https://img.shields.io/badge/python-3.12-blue" alt="Python 3.12">
-<img src="https://img.shields.io/badge/tests-262%20passed-green" alt="262 tests passed">
+<img src="https://img.shields.io/badge/tests-265%20passed-green" alt="265 tests passed">
 </p>

A local AI learning assistant with long-term memory, role-based group chat,
@@ -31,7 +31,7 @@ Study Agent is a local-first AI learning assistant; the focus is not simply calling
 - **SSRF protection** for article fetching, **detect-secrets** in CI
 - **Batched session logging** and multi-layer caching for performance
 - **Performance budget**: mode-based `max_tokens` bounds on the main chat, WeChat, and news LLM paths
-- **262 pytest tests**, Ruff clean, GitHub Actions CI workflow
+- **265 pytest tests**, Ruff clean, GitHub Actions CI workflow

For a detailed breakdown of the stack and engineering highlights, see [Technical Stack & Engineering Highlights](docs/TECH_STACK.md).

@@ -255,7 +255,7 @@ pip-compile requirements-dev.in # re-lock dev dependencies
 ## Testing

 ```bash
-pytest tests/ -v # current local baseline: 262 passed
+pytest tests/ -v # current local baseline: 265 passed
 pytest tests/ --cov=src # coverage
 ruff check src/ tests/ # linting
 mypy --explicit-package-bases src/ # CI soft check; may report type debt
 ```
2 changes: 1 addition & 1 deletion docs/INTERVIEW_NOTES.md
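@@ -10,7 +10,7 @@ Study Agent is a local-first AI learning assistant, focused on multi-Provider mod
 2. **Safe long-term memory writes**: safe writer plus a preview/confirm mechanism that prevents irreversible memory pollution
 3. **Source tracing for web search**: Feed registry / multi-source RSS aggregation → URL safety matrix → three-layer article body extraction → LLM digest → pipeline trace, with sources traceable end to end
 4. **Streamlit rerender performance optimization**: multi-layer caching strategy, mode-based batched flushes to disk, token budget control on the main path
-5. **CI / Ruff / detect-secrets engineering checks**: 262 pytest tests, Ruff clean, GitHub Actions workflow, detect-secrets hard-blocks any finding that is not allowlisted
+5. **CI / Ruff / detect-secrets engineering checks**: 265 pytest tests, Ruff clean, GitHub Actions workflow, detect-secrets hard-blocks any finding that is not allowlisted

 ## Talking Points
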
14 changes: 11 additions & 3 deletions docs/RAG.md
@@ -116,15 +116,23 @@ Regression coverage lives in `tests/test_rag.py` and verifies:
 - Empty-result and miss accounting
 - Unknown retrieval mode rejection

+P4-B adds API/query diagnostics:
+
+- Retrieval mode, `top_k`, `min_score` and tokenized query terms
+- Candidate count and returned result count
+- Per-result rank, chunk id, source path, matched terms and score breakdown
+- Optional one-query evaluation when `/rag/query` receives `expected_sources`
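
Reviewer note: to make the diagnostics listed above concrete, here is a minimal sketch of how a client might summarize the `debug` payload. The key names follow the dictionary returned by `build_rag_debug` in this PR's `src/rag/service.py`; the helper `summarize_debug` and every value in `SAMPLE_DEBUG` are hypothetical and exist only for illustration.

```python
from typing import Any


def summarize_debug(debug: dict[str, Any]) -> list[str]:
    """Render a header line plus one line per returned result."""
    header = (
        f"mode={debug['retrieval_mode']} top_k={debug['top_k']} "
        f"returned={debug['returned_count']}/{debug['candidate_count']}"
    )
    lines = [header]
    for item in debug["results"]:
        # Sort the breakdown keys so the output is stable across runs.
        breakdown = ", ".join(
            f"{name}={value}" for name, value in sorted(item["score_breakdown"].items())
        )
        lines.append(f"#{item['rank']} {item['source_path']} [{breakdown}]")
    return lines


# Illustrative payload only: these values are invented for the sketch.
SAMPLE_DEBUG: dict[str, Any] = {
    "retrieval_mode": "lexical",
    "top_k": 5,
    "min_score": 0.01,
    "candidate_count": 12,
    "returned_count": 1,
    "query_terms": ["memory", "safe"],
    "empty_query": False,
    "results": [
        {
            "rank": 1,
            "chunk_id": "notes-0",
            "source_path": "docs/notes.md",
            "matched_terms": ["memory"],
            "score_breakdown": {"lexical_score": 0.42},
        }
    ],
}

for line in summarize_debug(SAMPLE_DEBUG):
    print(line)
```

A summary like this could be a starting point for the planned Streamlit debug panel, though the panel's actual design is not described in this PR.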

## Next Steps

### P4: Retrieval Quality Loop

Goal: prove retrieval quality before expanding the stack.

-- Add a small gold fixture set with queries, expected sources and expected terms.
-- Track `recall@k`, mean reciprocal rank, source hit rate and empty-result rate.
-- Surface retrieval debug data in tests and API responses before adding more UI polish.
+- [x] Add a small gold fixture set with queries, expected sources and expected terms.
+- [x] Track `recall@k`, mean reciprocal rank, source hit rate and empty-result rate.
+- [x] Surface retrieval debug data in tests and API responses before adding more UI polish.
+- [ ] Add a Streamlit source/debug panel for inspecting score breakdowns.
 - Keep the first evaluation layer LLM-free so CI can catch retrieval regressions deterministically.
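
Reviewer note: the metrics named in the checklist above can be computed without any LLM call, which is what keeps them deterministic in CI. The sketch below uses the textbook definitions of `recall@k` and mean reciprocal rank. It is an assumption that the project's `src/rag/eval.py` follows the same definitions; the function names here are hypothetical and not taken from the repository.

```python
from collections.abc import Sequence


def recall_at_k(retrieved: Sequence[str], expected: Sequence[str], k: int) -> float:
    """Fraction of distinct expected sources found in the top-k retrieved sources."""
    wanted = set(expected)
    if not wanted:
        return 0.0
    top = set(retrieved[:k])
    return len(wanted & top) / len(wanted)


def reciprocal_rank(retrieved: Sequence[str], expected: Sequence[str]) -> float:
    """1 / rank of the first retrieved source that is expected, or 0.0 on a miss."""
    wanted = set(expected)
    for rank, source in enumerate(retrieved, start=1):
        if source in wanted:
            return 1.0 / rank
    return 0.0


def mean_reciprocal_rank(
    cases: Sequence[tuple[Sequence[str], Sequence[str]]],
) -> float:
    """Average reciprocal rank over (retrieved, expected) pairs."""
    if not cases:
        return 0.0
    return sum(reciprocal_rank(retrieved, expected) for retrieved, expected in cases) / len(cases)


# Two hypothetical evaluation cases: (retrieved sources in rank order, expected sources).
cases = [
    (["a.md", "b.md", "c.md"], ["b.md", "d.md"]),
    (["c.md"], ["c.md"]),
]
print(recall_at_k(*cases[0], k=2), mean_reciprocal_rank(cases))
```

Because both functions depend only on ranked source paths, a gold fixture of queries and expected sources is enough to detect a retrieval regression.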

### P5: Real Embedding Backend
8 changes: 4 additions & 4 deletions docs/TESTING.md
@@ -6,7 +6,7 @@ Current verified baseline:

 | Check | Status | Evidence |
 |---|---|---|
-| pytest | Passed | `262 passed` locally on 2026-06-05 |
+| pytest | Passed | `265 passed` locally on 2026-06-05 |
| Ruff | Passed | `python -m ruff check .` clean locally on 2026-06-04 |
| Package helper | Passed | `python tools/package_project_helper.py . NUL 0` locally on 2026-06-04 |
| mypy | Soft check, not clean | `python -m mypy --explicit-package-bases src/` reported 18 errors locally on 2026-06-04 |
@@ -24,9 +24,9 @@ Current verified baseline:
 | **News URL safety** | `test_url_normalizer.py`, `test_link_resolver.py` | 28 |
 | **News pipeline trace / audit** | `test_news_pipeline_trace.py`, `test_news_audit.py` | 5 |
 | **Feed registry / health** | `test_feed_registry.py`, `test_feed_diagnostics.py` | 9 |
-| **RAG MVP** | `test_rag.py` | 20 |
+| **RAG MVP** | `test_rag.py` | 22 |
 | **RAG evaluation** | `test_rag_eval.py` | 5 |
-| **FastAPI RAG endpoints** | `test_api.py` | 5 |
+| **FastAPI RAG endpoints** | `test_api.py` | 6 |
| **Architecture flows** | `test_architecture_flows.py` | 12 |
| **WeChat decoupling** | `test_wechat_decoupling.py` | 4 |
| **Sidebar rerun** | `test_sidebar_global_rerun.py` | 12 |
@@ -76,7 +76,7 @@ def test_flush_uses_safe_writer():
 ## Running Tests

 ```bash
-python -m pytest # current baseline: 262 passed
+python -m pytest # current baseline: 265 passed
 pytest tests/ -v # Verbose
 pytest tests/ --cov=src # Coverage
 python -m ruff check . # Linting
 ```
38 changes: 34 additions & 4 deletions src/api.py
@@ -6,8 +6,10 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field

-from src.rag import build_rag_context, format_rag_sources, index_documents, query_documents
-from src.rag.index import DEFAULT_RAG_INDEX_PATH
+from src.rag import build_rag_context, format_rag_sources, index_documents
+from src.rag.eval import RagEvalCase, evaluate_case
+from src.rag.index import DEFAULT_RAG_INDEX_PATH, load_rag_index
+from src.rag.service import build_rag_debug, search_documents


 class HealthResponse(BaseModel):
@@ -36,6 +38,8 @@ class RagQueryRequest(BaseModel):
     min_score: float = Field(default=0.01, ge=0)
     retrieval_mode: str = Field(default="hybrid")
     context_max_chars: int = Field(default=3000, gt=0, le=20_000)
+    expected_sources: list[str] = Field(default_factory=list)
+    expected_terms: list[str] = Field(default_factory=list)


 class RagQueryResponse(BaseModel):
@@ -45,6 +49,8 @@ class RagQueryResponse(BaseModel):
     context: str
     sources: str
     results: list[dict[str, Any]]
+    debug: dict[str, Any]
+    evaluation: dict[str, Any] | None = None


 app = FastAPI(title="Study Agent API", version="0.1.0")
@@ -88,13 +94,35 @@ def build_rag_index_endpoint(request: RagIndexRequest) -> RagIndexResponse:
 @app.post("/rag/query", response_model=RagQueryResponse)
 def query_rag_endpoint(request: RagQueryRequest) -> RagQueryResponse:
     try:
-        results = query_documents(
+        index = load_rag_index(_index_path(request.index_path))
+        results = search_documents(
+            index,
             request.query,
-            index_path=_index_path(request.index_path),
             top_k=request.top_k,
             min_score=request.min_score,
             retrieval_mode=request.retrieval_mode,
         )
+        debug = build_rag_debug(
+            index,
+            request.query,
+            results,
+            retrieval_mode=request.retrieval_mode,
+            top_k=request.top_k,
+            min_score=request.min_score,
+        )
+        evaluation = None
+        if request.expected_sources:
+            evaluation = evaluate_case(
+                index,
+                RagEvalCase(
+                    query=request.query,
+                    expected_sources=tuple(request.expected_sources),
+                    expected_terms=tuple(request.expected_terms),
+                    top_k=request.top_k,
+                    retrieval_mode=request.retrieval_mode,
+                ),
+                min_score=request.min_score,
+            ).to_dict()
     except FileNotFoundError as exc:
         raise HTTPException(status_code=404, detail="RAG index not found") from exc
     except ValueError as exc:
@@ -107,6 +135,8 @@ def query_rag_endpoint(request: RagQueryRequest) -> RagQueryResponse:
         context=build_rag_context(results, max_chars=request.context_max_chars),
         sources=format_rag_sources(results),
         results=[result.to_dict() for result in results],
+        debug=debug,
+        evaluation=evaluation,
     )
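
Reviewer note: a minimal sketch of a client-side request body for `/rag/query`, assuming the field names visible in `RagQueryRequest` in this diff. The helper `build_rag_query_body` is hypothetical and not part of the PR. It illustrates one consequence of the endpoint logic above: `evaluation` is computed only when `expected_sources` is non-empty, so `expected_terms` sent on its own has no effect.

```python
import json
from collections.abc import Sequence
from typing import Any


def build_rag_query_body(
    query: str,
    *,
    top_k: int,
    retrieval_mode: str = "hybrid",
    min_score: float = 0.01,
    expected_sources: Sequence[str] = (),
    expected_terms: Sequence[str] = (),
) -> dict[str, Any]:
    """Build a JSON-serializable body for POST /rag/query (hypothetical helper)."""
    body: dict[str, Any] = {
        "query": query,
        "top_k": top_k,
        "min_score": min_score,
        "retrieval_mode": retrieval_mode,
    }
    # The endpoint evaluates only when expected_sources is non-empty,
    # so expected_terms is attached only alongside it.
    if expected_sources:
        body["expected_sources"] = list(expected_sources)
        body["expected_terms"] = list(expected_terms)
    return body


plain = build_rag_query_body("safe writer", top_k=3)
graded = build_rag_query_body(
    "safe writer",
    top_k=3,
    expected_sources=["docs/RAG.md"],
    expected_terms=["writer"],
)
print(json.dumps(graded, indent=2))
```

With the `plain` body the response would carry `evaluation: null`; with the `graded` body the endpoint would also run the one-query evaluation against the same index.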


4 changes: 4 additions & 0 deletions src/rag/__init__.py
@@ -15,10 +15,12 @@
     load_eval_cases,
 )
 from src.rag.service import (
+    build_rag_debug,
     build_rag_context,
     format_rag_sources,
     index_documents,
     query_documents,
+    search_documents,
 )
 from src.rag.vector import (
     cosine_similarity,
@@ -28,6 +30,7 @@
 )

 __all__ = [
+    "build_rag_debug",
     "build_rag_context",
     "build_rag_index",
     "cosine_similarity",
@@ -46,4 +49,5 @@
     "search_rag_index_hybrid",
     "search_rag_index",
     "search_rag_index_vector",
+    "search_documents",
 ]
135 changes: 133 additions & 2 deletions src/rag/service.py
@@ -2,16 +2,28 @@

 from collections.abc import Sequence
 from pathlib import Path
+from typing import Any

 from src.rag.index import (
     DEFAULT_RAG_INDEX_PATH,
+    _document_frequency,
+    _score_chunk,
+    _tokenize,
     build_rag_index,
     load_rag_index,
     save_rag_index,
     search_rag_index,
 )
 from src.rag.schema import RagIndex, RagSearchResult
-from src.rag.vector import search_rag_index_hybrid, search_rag_index_vector
+from src.rag.vector import (
+    cosine_similarity,
+    embed_text,
+    search_rag_index_hybrid,
+    search_rag_index_vector,
+)

 RETRIEVAL_MODES = {"lexical", "vector", "hybrid"}
+HYBRID_LEXICAL_WEIGHT = 0.7


 def index_documents(
@@ -41,15 +53,134 @@ def query_documents(
 ) -> list[RagSearchResult]:
     """Search the persisted local RAG index."""
     index = load_rag_index(index_path)
+    return search_documents(
+        index,
+        query,
+        top_k=top_k,
+        min_score=min_score,
+        retrieval_mode=retrieval_mode,
+    )
+
+
+def search_documents(
+    index: RagIndex,
+    query: str,
+    *,
+    top_k: int = 5,
+    min_score: float = 0.01,
+    retrieval_mode: str = "lexical",
+) -> list[RagSearchResult]:
+    """Search an in-memory RAG index."""
     if retrieval_mode == "lexical":
         return search_rag_index(index, query, top_k=top_k, min_score=min_score)
     if retrieval_mode == "vector":
         return search_rag_index_vector(index, query, top_k=top_k, min_score=min_score)
     if retrieval_mode == "hybrid":
-        return search_rag_index_hybrid(index, query, top_k=top_k, min_score=min_score)
+        return search_rag_index_hybrid(
+            index,
+            query,
+            top_k=top_k,
+            min_score=min_score,
+            lexical_weight=HYBRID_LEXICAL_WEIGHT,
+        )
     raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}")


+def _lexical_scores(index: RagIndex, query: str) -> dict[str, float]:
+    if not index.chunks:
+        return {}
+    df = _document_frequency(index.chunks)
+    return {
+        chunk.chunk_id: round(_score_chunk(query, chunk, df, len(index.chunks))[0], 6)
+        for chunk in index.chunks
+    }
+
+
+def _vector_scores(index: RagIndex, query: str) -> dict[str, float]:
+    query_vector = embed_text(query)
+    if not any(query_vector):
+        return {chunk.chunk_id: 0.0 for chunk in index.chunks}
+    return {
+        chunk.chunk_id: round(cosine_similarity(query_vector, embed_text(chunk.text)), 6)
+        for chunk in index.chunks
+    }

Review comment on lines +103 to +105 (P2): Report clipped hybrid vector scores

For hybrid searches, `search_rag_index_hybrid()` builds `vector_scores` from `search_rag_index_vector(..., min_score=0.0)`, so negative cosine scores are omitted and contribute 0.0 to the combined score. This debug path instead records the raw cosine for every chunk, so any lexical hit whose vector similarity is negative will show a `vector_score` that was not actually used, and the breakdown cannot explain the returned `combined_score`.
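
Reviewer note: the mismatch described in the review comment above can be reproduced with a few lines. This is a sketch under stated assumptions, not the repository's implementation: the blend `weight * lexical_normalized + (1 - weight) * vector_score` is assumed (the diff only shows that a `lexical_weight` of 0.7 is passed in), and `clip_vector_score` and `combined_score` are hypothetical names.

```python
HYBRID_LEXICAL_WEIGHT = 0.7  # same value as the constant added in this PR


def clip_vector_score(raw_cosine: float) -> float:
    """Treat a negative cosine similarity as 0.0, as a min_score=0.0 filter would."""
    return max(raw_cosine, 0.0)


def combined_score(lexical_normalized: float, vector_score: float) -> float:
    """ASSUMED blend: weight * lexical + (1 - weight) * vector, rounded to 6 places."""
    weight = HYBRID_LEXICAL_WEIGHT
    return round(weight * lexical_normalized + (1.0 - weight) * vector_score, 6)


# A chunk with the best lexical score but a negative cosine similarity.
raw_cosine = -0.2
explained_by_raw = combined_score(1.0, raw_cosine)
explained_by_clipped = combined_score(1.0, clip_vector_score(raw_cosine))
print(explained_by_raw, explained_by_clipped)
```

Under these assumptions the two values differ, so a breakdown that reports the raw cosine cannot be reconciled with the score the search actually returned. Clipping inside `_vector_scores` for hybrid mode would be one way to keep the debug output consistent.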


+def _score_breakdown(
+    result: RagSearchResult,
+    *,
+    retrieval_mode: str,
+    lexical_scores: dict[str, float],
+    vector_scores: dict[str, float],
+    max_lexical_score: float,
+) -> dict[str, float]:
+    chunk_id = result.chunk.chunk_id
+    lexical_score = lexical_scores.get(chunk_id, 0.0)
+    lexical_normalized = lexical_score / max_lexical_score if max_lexical_score > 0 else 0.0
+    vector_score = vector_scores.get(chunk_id, 0.0)
+
+    if retrieval_mode == "lexical":
+        return {"lexical_score": round(lexical_score, 6)}
+    if retrieval_mode == "vector":
+        return {"vector_score": round(vector_score, 6)}
+    if retrieval_mode == "hybrid":
+        return {
+            "lexical_weight": HYBRID_LEXICAL_WEIGHT,
+            "lexical_score": round(lexical_score, 6),
+            "lexical_normalized": round(lexical_normalized, 6),
+            "vector_score": round(vector_score, 6),
+            "combined_score": round(result.score, 6),
+        }
+    raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}")
+
+
+def build_rag_debug(
+    index: RagIndex,
+    query: str,
+    results: list[RagSearchResult],
+    *,
+    retrieval_mode: str,
+    top_k: int,
+    min_score: float,
+) -> dict[str, Any]:
+    """Build explainable retrieval diagnostics for API and evaluation views."""
+    if retrieval_mode not in RETRIEVAL_MODES:
+        raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}")
+
+    lexical_scores = _lexical_scores(index, query)
+    vector_scores = _vector_scores(index, query)
+    max_lexical_score = max(lexical_scores.values(), default=0.0)
+    query_terms = tuple(sorted(set(_tokenize(query))))
+
+    return {
+        "retrieval_mode": retrieval_mode,
+        "top_k": top_k,
+        "min_score": min_score,
+        "candidate_count": len(index.chunks),
+        "returned_count": len(results),
+        "query_terms": list(query_terms),
+        "empty_query": not query_terms,
+        "results": [
+            {
+                "rank": rank,
+                "chunk_id": result.chunk.chunk_id,
+                "source_path": result.chunk.source_path,
+                "title": result.chunk.title,
+                "score": result.score,
+                "matched_terms": list(result.matched_terms),
+                "score_breakdown": _score_breakdown(
+                    result,
+                    retrieval_mode=retrieval_mode,
+                    lexical_scores=lexical_scores,
+                    vector_scores=vector_scores,
+                    max_lexical_score=max_lexical_score,
+                ),
+            }
+            for rank, result in enumerate(results, start=1)
+        ],
+    }
+
+
 def build_rag_context(
     results: list[RagSearchResult],
     *,