Skip to content

Commit 21c26c3

Browse files
chengjieclaude
andcommitted
feat: 添加 CORS 支持和列表 API(修复前端跨域和数据加载问题)
## 修改内容 ### 1. 添加 CORS Middleware(main.py) - 配置允许的前端域名(生产 + 本地开发) - 支持 OPTIONS 预检请求 - 允许的方法:GET, POST, PUT, DELETE, OPTIONS - 允许的 Headers:Content-Type, Authorization, X-Request-ID ### 2. 新增任务列表 API(api/task.py) - 端点:GET /tasks - 功能:分页、按状态过滤、排序 - 参数验证:Literal 类型 + Query 约束 - 分页元数据:total_pages, has_next, has_prev ### 3. 新增文档列表 API(api/documents.py) - 端点:GET /documents - 功能:使用 LightRAG 原生 get_docs_paginated() - 支持:分页、状态过滤、排序 - 端点:GET /documents/status_counts - 功能:返回各状态的文档数量统计 ## 技术细节 - ✅ 验证了 LightRAG doc_status API 的实际行为(返回 tuple[list, int]) - ✅ 使用 Literal 类型限制枚举值(status, sort_field, sort_direction) - ✅ 添加参数验证(page_size ≤ 100, page ≤ 10000) - ✅ 完善错误处理(501 Not Implemented, 500 Internal Server Error) - ✅ 本地测试通过(CORS 预检返回 200,列表 API 正常工作) ## 解决的问题 - 🐛 修复前端 CORS 预检失败(OPTIONS 返回 405) - 🐛 修复前端刷新后列表为空(缺少列表 API) - 🐛 修复前端无法获取文档状态统计 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 0f297da commit 21c26c3

3 files changed

Lines changed: 301 additions & 2 deletions

File tree

api/documents.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from fastapi import APIRouter, HTTPException, Depends, Query, BackgroundTasks
6+
from typing import Optional, Literal
67
from src.multi_tenant import get_tenant_lightrag
78
from src.tenant_deps import get_tenant_id
89
from src.logger import logger
@@ -185,3 +186,156 @@ async def delete_document(
185186
"doc_id": doc_id,
186187
"tenant_id": tenant_id
187188
}
189+
190+
191+
# ============ GET 文档列表 ============
192+
193+
@router.get("/documents")
194+
async def list_documents(
195+
tenant_id: str = Depends(get_tenant_id),
196+
page: int = Query(1, ge=1, le=10000, description="页码(从 1 开始)"),
197+
page_size: int = Query(50, ge=1, le=100, description="每页数量(最多 100)"),
198+
status_filter: Optional[Literal["pending", "processing", "preprocessed", "processed", "failed"]] = None,
199+
sort_field: Literal["created_at", "updated_at"] = Query("created_at"),
200+
sort_direction: Literal["asc", "desc"] = Query("desc")
201+
):
202+
"""
203+
获取租户的文档列表(支持分页、过滤、排序)
204+
205+
**功能**:
206+
- ✅ 分页:page, page_size
207+
- ✅ 过滤:status_filter (pending/processing/preprocessed/processed/failed)
208+
- ✅ 排序:sort_field (created_at/updated_at), sort_direction (asc/desc)
209+
- ✅ 使用 LightRAG 原生分页 API
210+
211+
**示例请求**:
212+
- GET /documents?tenant_id=tenant_a&page=1&page_size=20
213+
- GET /documents?tenant_id=tenant_a&status_filter=processed&sort_field=updated_at&sort_direction=desc
214+
215+
**示例响应**:
216+
```json
217+
{
218+
"documents": [
219+
{
220+
"doc_id": "doc-abc123",
221+
"status": "processed",
222+
"file_path": "research_paper.pdf",
223+
"created_at": "2025-11-06T10:00:00",
224+
"updated_at": "2025-11-06T10:05:00"
225+
}
226+
],
227+
"pagination": {
228+
"total": 100,
229+
"page": 1,
230+
"page_size": 20,
231+
"total_pages": 5,
232+
"has_next": true,
233+
"has_prev": false
234+
}
235+
}
236+
```
237+
"""
238+
try:
239+
# 获取 LightRAG 实例
240+
lightrag = await get_tenant_lightrag(tenant_id)
241+
242+
# 验证 doc_status 是否可用
243+
if not hasattr(lightrag, 'doc_status'):
244+
raise HTTPException(
245+
status_code=501,
246+
detail="Document status storage not available"
247+
)
248+
249+
# 调用 LightRAG 的分页方法
250+
docs_list, total = await lightrag.doc_status.get_docs_paginated(
251+
status_filter=status_filter,
252+
page=page,
253+
page_size=page_size,
254+
sort_field=sort_field,
255+
sort_direction=sort_direction
256+
)
257+
258+
# 格式化文档数据
259+
documents = []
260+
for doc in docs_list:
261+
# docs_list 是列表,每个元素是文档对象
262+
if hasattr(doc, '__dict__'):
263+
doc_dict = doc.__dict__.copy()
264+
elif isinstance(doc, dict):
265+
doc_dict = doc.copy()
266+
else:
267+
# 尝试转换为字典
268+
doc_dict = {"raw_data": str(doc)}
269+
270+
documents.append(doc_dict)
271+
272+
return {
273+
"documents": documents,
274+
"pagination": {
275+
"total": total,
276+
"page": page,
277+
"page_size": page_size,
278+
"total_pages": (total + page_size - 1) // page_size if total > 0 else 0,
279+
"has_next": page * page_size < total,
280+
"has_prev": page > 1
281+
}
282+
}
283+
284+
except NotImplementedError as e:
285+
logger.error(f"get_docs_paginated not implemented: {e}")
286+
raise HTTPException(
287+
status_code=501,
288+
detail="Document pagination not implemented in current LightRAG version"
289+
)
290+
except Exception as e:
291+
logger.error(f"Failed to list documents for tenant {tenant_id}: {e}", exc_info=True)
292+
raise HTTPException(
293+
status_code=500,
294+
detail=f"Failed to retrieve documents: {str(e)}"
295+
)
296+
297+
298+
@router.get("/documents/status_counts")
299+
async def get_document_status_counts(tenant_id: str = Depends(get_tenant_id)):
300+
"""
301+
获取文档状态统计
302+
303+
**功能**:
304+
- 返回各状态的文档数量(pending/processing/preprocessed/processed/failed)
305+
306+
**示例响应**:
307+
```json
308+
{
309+
"status_counts": {
310+
"pending": 5,
311+
"processing": 2,
312+
"preprocessed": 3,
313+
"processed": 100,
314+
"failed": 1,
315+
"all": 111
316+
}
317+
}
318+
```
319+
"""
320+
try:
321+
# 获取 LightRAG 实例
322+
lightrag = await get_tenant_lightrag(tenant_id)
323+
324+
# 验证 doc_status 是否可用
325+
if not hasattr(lightrag, 'doc_status'):
326+
raise HTTPException(
327+
status_code=501,
328+
detail="Document status storage not available"
329+
)
330+
331+
# 调用 LightRAG 的统计方法
332+
counts = await lightrag.doc_status.get_all_status_counts()
333+
334+
return {"status_counts": counts}
335+
336+
except Exception as e:
337+
logger.error(f"Failed to get status counts for tenant {tenant_id}: {e}")
338+
raise HTTPException(
339+
status_code=500,
340+
detail="Failed to get status counts"
341+
)

api/task.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
"""
44

55
from datetime import datetime
6-
from fastapi import APIRouter, HTTPException, Depends
6+
from typing import Optional, Literal
7+
from fastapi import APIRouter, HTTPException, Depends, Query
78

89
from src.logger import logger
910
from src.tenant_deps import get_tenant_id
10-
from .task_store import get_task, update_task
11+
from .task_store import get_task, update_task, get_tenant_tasks
1112
from .models import TaskStatus, TaskInfo
1213

1314
router = APIRouter()
@@ -184,3 +185,132 @@ async def sync_task_with_lightrag(task: TaskInfo, tenant_id: str) -> TaskInfo:
184185
)
185186

186187
return task
188+
189+
190+
@router.get("/tasks")
191+
async def list_tasks(
192+
tenant_id: str = Depends(get_tenant_id),
193+
status: Optional[Literal["pending", "processing", "completed", "failed"]] = None,
194+
page: int = Query(1, ge=1, le=10000, description="页码(从 1 开始)"),
195+
page_size: int = Query(50, ge=1, le=100, description="每页数量(最多 100)"),
196+
sort_by: Literal["created_at", "updated_at", "status"] = Query("created_at"),
197+
sort_order: Literal["asc", "desc"] = Query("desc")
198+
):
199+
"""
200+
获取租户的任务列表(支持分页、过滤、排序)
201+
202+
**功能**:
203+
- ✅ 分页:page, page_size
204+
- ✅ 过滤:status (pending/processing/completed/failed)
205+
- ✅ 排序:sort_by (created_at/updated_at/status), sort_order (asc/desc)
206+
207+
**注意**:
208+
- 当前在内存中分页,如果任务量 >10000,性能会下降
209+
- 建议未来在存储层实现分页
210+
211+
**示例请求**:
212+
- GET /tasks?tenant_id=tenant_a&page=1&page_size=20
213+
- GET /tasks?tenant_id=tenant_a&status=completed&sort_by=updated_at&sort_order=desc
214+
215+
**示例响应**:
216+
```json
217+
{
218+
"tasks": [
219+
{
220+
"task_id": "xxx",
221+
"tenant_id": "tenant_a",
222+
"status": "completed",
223+
"doc_id": "doc_001",
224+
"filename": "test.pdf",
225+
"created_at": "2025-10-14T20:00:00",
226+
"updated_at": "2025-10-14T20:02:30"
227+
}
228+
],
229+
"pagination": {
230+
"total": 100,
231+
"page": 1,
232+
"page_size": 20,
233+
"total_pages": 5,
234+
"has_next": true,
235+
"has_prev": false
236+
}
237+
}
238+
```
239+
"""
240+
try:
241+
# 获取所有任务
242+
tasks_dict = get_tenant_tasks(tenant_id)
243+
244+
# 如果没有任务,返回空列表
245+
if not tasks_dict:
246+
return {
247+
"tasks": [],
248+
"pagination": {
249+
"total": 0,
250+
"page": page,
251+
"page_size": page_size,
252+
"total_pages": 0,
253+
"has_next": False,
254+
"has_prev": False
255+
}
256+
}
257+
258+
tasks_list = list(tasks_dict.values())
259+
260+
# 过滤状态
261+
if status:
262+
tasks_list = [t for t in tasks_list if t.status.value == status]
263+
264+
# 排序
265+
reverse = (sort_order == "desc")
266+
tasks_list.sort(
267+
key=lambda t: getattr(t, sort_by, 0) or 0,
268+
reverse=reverse
269+
)
270+
271+
# 分页
272+
total = len(tasks_list)
273+
start = (page - 1) * page_size
274+
end = start + page_size
275+
tasks_page = tasks_list[start:end]
276+
277+
# 转换为 dict(确保可序列化)
278+
tasks_data = []
279+
for t in tasks_page:
280+
if hasattr(t, 'dict'):
281+
tasks_data.append(t.dict())
282+
else:
283+
# 手动转换为字典
284+
task_dict = {
285+
"task_id": t.task_id,
286+
"tenant_id": t.tenant_id,
287+
"status": t.status.value,
288+
"doc_id": t.doc_id,
289+
"filename": t.filename,
290+
"created_at": t.created_at,
291+
"updated_at": t.updated_at
292+
}
293+
if hasattr(t, 'result') and t.result:
294+
task_dict["result"] = t.result
295+
if hasattr(t, 'error') and t.error:
296+
task_dict["error"] = t.error
297+
tasks_data.append(task_dict)
298+
299+
return {
300+
"tasks": tasks_data,
301+
"pagination": {
302+
"total": total,
303+
"page": page,
304+
"page_size": page_size,
305+
"total_pages": (total + page_size - 1) // page_size,
306+
"has_next": end < total,
307+
"has_prev": page > 1
308+
}
309+
}
310+
311+
except Exception as e:
312+
logger.error(f"Failed to list tasks for tenant {tenant_id}: {e}")
313+
raise HTTPException(
314+
status_code=500,
315+
detail="Failed to retrieve tasks"
316+
)

main.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from fastapi import FastAPI
9+
from fastapi.middleware.cors import CORSMiddleware
910

1011
# 导入 RAG 相关模块
1112
from src.rag import lifespan
@@ -94,6 +95,20 @@
9495
]
9596
)
9697

98+
# 添加 CORS middleware(修复前端跨域问题)
99+
app.add_middleware(
100+
CORSMiddleware,
101+
allow_origins=[
102+
"https://main.d2bxt3tjxqfsjq.amplifyapp.com", # 前端生产域名
103+
"http://localhost:3000", # 本地开发(React)
104+
"http://localhost:5173", # 本地开发(Vite)
105+
],
106+
allow_credentials=False, # 不发送 cookies,降低安全风险
107+
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
108+
allow_headers=["Content-Type", "Authorization", "X-Request-ID"],
109+
max_age=3600, # 预检请求缓存 1 小时
110+
)
111+
97112
# 注册 API 路由
98113
app.include_router(api_router)
99114

0 commit comments

Comments
 (0)