diff --git a/README.md b/README.md
index da31e11..8af0cdf 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,8 @@
 - 🌐 **代理支持** - 支持 HTTP/SOCKS5 代理
 - 📱 **Web 管理界面** - 直观的 Token 和配置管理
 - 🎨 **图片生成连续对话**
+- 🧩 **Gemini 官方请求体兼容** - 支持 `generateContent` / `streamGenerateContent`、`systemInstruction`、`contents.parts.text/inlineData/fileData`
+- ✅ **Gemini 官方格式已实测出图** - 已使用真实 Token 验证 `/models/{model}:generateContent` 可正常返回官方 `candidates[].content.parts[].inlineData`
 
 ## 🚀 快速开始
 
@@ -238,6 +240,63 @@ python main.py
 
 ## 📡 API 使用示例(需要使用流式)
 
+> 除了下方 `OpenAI-compatible` 示例,服务也支持 Gemini 官方格式:
+> - `POST /v1beta/models/{model}:generateContent`
+> - `POST /models/{model}:generateContent`
+> - `POST /v1beta/models/{model}:streamGenerateContent`
+> - `POST /models/{model}:streamGenerateContent`
+>
+> Gemini 官方格式支持以下认证方式:
+> - `Authorization: Bearer <API_KEY>`
+> - `x-goog-api-key: <API_KEY>`
+> - `?key=<API_KEY>`
+>
+> Gemini 官方图片请求体已兼容:
+> - `systemInstruction`
+> - `contents[].parts[].text`
+> - `contents[].parts[].inlineData`
+> - `contents[].parts[].fileData.fileUri`
+> - `generationConfig.responseModalities`
+> - `generationConfig.imageConfig.aspectRatio`
+> - `generationConfig.imageConfig.imageSize`
+
+### Gemini 官方 generateContent(文生图)
+
+> 已使用真实 Token 实测通过。
+> 如需流式返回,可将路径替换为 `:streamGenerateContent?alt=sse`。
+
+```bash
+curl -X POST "http://localhost:8000/models/gemini-3.1-flash-image:generateContent" \
+  -H "x-goog-api-key: han1234" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "systemInstruction": {
+      "parts": [
+        {
+          "text": "Return an image only."
+        }
+      ]
+    },
+    "contents": [
+      {
+        "role": "user",
+        "parts": [
+          {
+            "text": "一颗放在木桌上的红苹果,棚拍光线,极简背景"
+          }
+        ]
+      }
+    ],
+    "generationConfig": {
+      "responseModalities": ["IMAGE"],
+      "imageConfig": {
+        "aspectRatio": "1:1",
+        "imageSize": "1K"
+      }
+    }
+  }'
+```
+
 ### 文生图
 
 ```bash
diff --git a/src/api/routes.py b/src/api/routes.py
index d243d0f..71865cf 100644
--- a/src/api/routes.py
+++ b/src/api/routes.py
@@ -1,73 +1,590 @@
-"""API routes - OpenAI compatible endpoints"""
+"""API routes for OpenAI-compatible and Gemini generateContent endpoints."""
 
-from fastapi import APIRouter, Depends, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
-from typing import List, Optional
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
 import base64
-import re
 import json
-import time
+import mimetypes
+import re
 from urllib.parse import urlparse
+
 from curl_cffi.requests import AsyncSession
-from ..core.auth import verify_api_key_header
-from ..core.models import ChatCompletionRequest
-from ..services.generation_handler import GenerationHandler, MODEL_CONFIG
+from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from ..core.auth import verify_api_key_flexible
 from ..core.logger import debug_logger
-from ..core.model_resolver import resolve_model_name, get_base_model_aliases
+from ..core.model_resolver import get_base_model_aliases, resolve_model_name
+from ..core.models import (
+    ChatCompletionRequest,
+    ChatMessage,
+    GeminiContent,
+    GeminiGenerateContentRequest,
+)
+from ..services.generation_handler import MODEL_CONFIG, GenerationHandler
 
 router = APIRouter()
 
+MARKDOWN_IMAGE_RE = re.compile(r"!\[.*?\]\((.*?)\)")
+HTML_VIDEO_RE = re.compile(r"<video[^>]+src=['\"](.*?)['\"]", re.IGNORECASE)
+DATA_URL_RE = re.compile(r"^data:(?P<mime>[^;]+);base64,(?P<data>.+)$", re.DOTALL)
+GEMINI_STATUS_MAP = {
+    400: "INVALID_ARGUMENT",
+    401: "UNAUTHENTICATED",
+    403: "PERMISSION_DENIED",
+    404: "NOT_FOUND",
+    409: "ABORTED",
+    429: "RESOURCE_EXHAUSTED",
+    500: "INTERNAL",
+    502: "UNAVAILABLE",
+    503: "UNAVAILABLE",
+    504: "DEADLINE_EXCEEDED",
+}
+
 # Dependency injection will be set up in main.py
 generation_handler: GenerationHandler = None
 
+
+@dataclass
+class NormalizedGenerationRequest:
+    """Internal request shape shared by OpenAI and Gemini entrypoints."""
+
+    model: str
+    prompt: str
+    images: List[bytes]
+    messages: Optional[List[ChatMessage]] = None
+
+
 def set_generation_handler(handler: GenerationHandler):
-    """Set generation handler instance"""
+    """Set generation handler instance."""
     global generation_handler
     generation_handler = handler
 
 
+def _ensure_generation_handler() -> GenerationHandler:
+    if generation_handler is None:
+        raise HTTPException(status_code=500, detail="Generation handler not initialized")
+    return generation_handler
+
+
+def _decode_data_url(data_url: str) -> tuple[str, bytes]:
+    match = DATA_URL_RE.match(data_url)
+    if not match:
+        raise HTTPException(status_code=400, detail="Invalid data URL")
+    return match.group("mime"), base64.b64decode(match.group("data"))
+
+
+def _detect_image_mime_type(image_bytes: bytes, fallback: str = "image/png") -> str:
+    if image_bytes.startswith(b"\xff\xd8\xff"):
+        return "image/jpeg"
+    if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
+        return "image/png"
+    if image_bytes.startswith(b"GIF87a") or image_bytes.startswith(b"GIF89a"):
+        return "image/gif"
+    if image_bytes.startswith(b"RIFF") and image_bytes[8:12] == b"WEBP":
+        return "image/webp"
+    return fallback
+
+
+def _guess_mime_type(uri: str, fallback: str) -> str:
+    guessed, _ = mimetypes.guess_type(urlparse(uri).path)
+    return guessed or fallback
+
+
 async def retrieve_image_data(url: str) -> Optional[bytes]:
-    """
-    智能获取图片数据:
-    1. 优先检查是否为本地 /tmp/ 缓存文件,如果是则直接读取磁盘
-    2. 
如果本地不存在或是外部链接,则进行网络下载 - """ - # 优先尝试本地读取 + """Read image bytes from local /tmp cache or remote URL.""" try: - if "/tmp/" in url and generation_handler and generation_handler.file_cache: + file_cache = getattr(generation_handler, "file_cache", None) + if "/tmp/" in url and file_cache: path = urlparse(url).path filename = path.split("/tmp/")[-1] - local_file_path = generation_handler.file_cache.cache_dir / filename + local_file_path = file_cache.cache_dir / filename if local_file_path.exists() and local_file_path.is_file(): data = local_file_path.read_bytes() if data: return data - except Exception as e: - debug_logger.log_warning(f"[CONTEXT] 本地缓存读取失败: {str(e)}") + except Exception as exc: + debug_logger.log_warning(f"[CONTEXT] 本地缓存读取失败: {str(exc)}") - # 回退逻辑:网络下载 try: async with AsyncSession() as session: response = await session.get( - url, timeout=30, impersonate="chrome110", verify=False + url, + timeout=30, + impersonate="chrome110", + verify=False, ) if response.status_code == 200: return response.content - else: - debug_logger.log_warning( - f"[CONTEXT] 图片下载失败,状态码: {response.status_code}" - ) - except Exception as e: - debug_logger.log_error(f"[CONTEXT] 图片下载异常: {str(e)}") + debug_logger.log_warning( + f"[CONTEXT] 图片下载失败,状态码: {response.status_code}" + ) + except Exception as exc: + debug_logger.log_error(f"[CONTEXT] 图片下载异常: {str(exc)}") return None +async def _load_image_bytes_from_uri(uri: str) -> bytes: + if not uri: + raise HTTPException(status_code=400, detail="Image URI cannot be empty") + + if uri.startswith("data:image"): + _, image_bytes = _decode_data_url(uri) + return image_bytes + + if uri.startswith("http://") or uri.startswith("https://") or "/tmp/" in uri: + image_bytes = await retrieve_image_data(uri) + if image_bytes: + return image_bytes + raise HTTPException(status_code=400, detail=f"Failed to load image from {uri}") + + raise HTTPException(status_code=400, detail=f"Unsupported image URI: {uri}") + + +def _coerce_gemini_contents(raw_contents: Optional[List[Any]]) -> List[GeminiContent]: + contents: List[GeminiContent] = [] + for item in raw_contents or []: + if isinstance(item, GeminiContent): + contents.append(item) + else: + contents.append(GeminiContent.model_validate(item)) + return contents + + +def _extract_text_from_gemini_content(content: Optional[GeminiContent]) -> str: + if content is None: + return "" + text_parts = [part.text.strip() for part in content.parts if part.text] + return "\n".join(part for part in text_parts if part).strip() + + +async def _extract_prompt_and_images_from_openai_messages( + messages: List[ChatMessage], +) -> tuple[str, List[bytes]]: + last_message = messages[-1] + content = last_message.content + prompt_parts: List[str] = [] + images: List[bytes] = [] + + if isinstance(content, str): + prompt_parts.append(content) + elif isinstance(content, list): + for item in content: + item_type = item.get("type") + if item_type == "text": + text = item.get("text", "").strip() + if text: + prompt_parts.append(text) + elif item_type == "image_url": + image_url = item.get("image_url", {}).get("url", "") + images.append(await _load_image_bytes_from_uri(image_url)) + + prompt = "\n".join(part for part in prompt_parts if part).strip() + return prompt, images + + +async def _append_openai_reference_images( + model: str, + messages: List[ChatMessage], + images: List[bytes], +) -> List[bytes]: + model_config = MODEL_CONFIG.get(model) + if not model_config or model_config["type"] != "image" or len(messages) <= 1: + return images + + 
debug_logger.log_info(f"[CONTEXT] 开始查找历史参考图,消息数量: {len(messages)}") + + for msg in reversed(messages[:-1]): + if msg.role == "assistant" and isinstance(msg.content, str): + matches = MARKDOWN_IMAGE_RE.findall(msg.content) + if not matches: + continue + + for image_url in reversed(matches): + if not image_url.startswith("http") and "/tmp/" not in image_url: + continue + try: + downloaded_bytes = await retrieve_image_data(image_url) + if downloaded_bytes: + images.insert(0, downloaded_bytes) + debug_logger.log_info( + f"[CONTEXT] ✅ 添加历史参考图: {image_url}" + ) + return images + debug_logger.log_warning( + f"[CONTEXT] 图片下载失败或为空,尝试下一个: {image_url}" + ) + except Exception as exc: + debug_logger.log_error( + f"[CONTEXT] 处理参考图时出错: {str(exc)}" + ) + return images + + +async def _extract_prompt_and_images_from_gemini_contents( + contents: List[GeminiContent], +) -> tuple[str, List[bytes]]: + if not contents: + raise HTTPException(status_code=400, detail="contents cannot be empty") + + target_content = next( + (content for content in reversed(contents) if (content.role or "user") == "user"), + contents[-1], + ) + + prompt_parts: List[str] = [] + images: List[bytes] = [] + + for part in target_content.parts: + if part.text: + text = part.text.strip() + if text: + prompt_parts.append(text) + elif part.inlineData is not None: + mime_type = part.inlineData.mimeType.lower() + if not mime_type.startswith("image/"): + raise HTTPException( + status_code=400, + detail=f"Unsupported inlineData mime type: {part.inlineData.mimeType}", + ) + images.append(base64.b64decode(part.inlineData.data)) + elif part.fileData is not None: + mime_type = (part.fileData.mimeType or "").lower() + if mime_type and not mime_type.startswith("image/"): + raise HTTPException( + status_code=400, + detail=f"Unsupported fileData mime type: {part.fileData.mimeType}", + ) + images.append(await _load_image_bytes_from_uri(part.fileData.fileUri)) + + prompt = "\n".join(part for part in prompt_parts if part).strip() + return prompt, images + + +def _resolve_request_model(model: str, request: Any) -> str: + resolved_model = resolve_model_name(model=model, request=request, model_config=MODEL_CONFIG) + if resolved_model != model: + debug_logger.log_info(f"[ROUTE] 模型名已转换: {model} → {resolved_model}") + return resolved_model + + +async def _normalize_openai_request( + request: ChatCompletionRequest, +) -> NormalizedGenerationRequest: + if request.messages: + prompt, images = await _extract_prompt_and_images_from_openai_messages( + request.messages + ) + if request.image and not images: + images.append(await _load_image_bytes_from_uri(request.image)) + model = _resolve_request_model(request.model, request) + images = await _append_openai_reference_images(model, request.messages, images) + return NormalizedGenerationRequest( + model=model, + prompt=prompt, + images=images, + messages=request.messages, + ) + + if request.contents: + gemini_request = GeminiGenerateContentRequest( + contents=_coerce_gemini_contents(request.contents), + generationConfig=request.generationConfig, + ) + normalized = await _normalize_gemini_request(request.model, gemini_request) + normalized.messages = request.messages + return normalized + + raise HTTPException(status_code=400, detail="Messages or contents cannot be empty") + + +async def _normalize_gemini_request( + model: str, + request: GeminiGenerateContentRequest, +) -> NormalizedGenerationRequest: + prompt, images = await _extract_prompt_and_images_from_gemini_contents(request.contents) + system_instruction = 
_extract_text_from_gemini_content(request.systemInstruction) + if system_instruction: + prompt = f"{system_instruction}\n\n{prompt}".strip() + + return NormalizedGenerationRequest( + model=_resolve_request_model(model, request), + prompt=prompt, + images=images, + ) + + +async def _collect_non_stream_result( + model: str, + prompt: str, + images: List[bytes], +) -> str: + handler = _ensure_generation_handler() + result = None + async for chunk in handler.handle_generation( + model=model, + prompt=prompt, + images=images if images else None, + stream=False, + ): + result = chunk + + if result is None: + raise HTTPException(status_code=500, detail="Generation failed: No response") + + return result + + +def _parse_handler_result(result: str) -> Dict[str, Any]: + try: + return json.loads(result) + except json.JSONDecodeError: + return {"result": result} + + +def _get_error_status_code(payload: Dict[str, Any]) -> int: + error = payload.get("error") + if isinstance(error, dict): + status_code = error.get("status_code") + if isinstance(status_code, int): + return status_code + if isinstance(status_code, str) and status_code.isdigit(): + return int(status_code) + return 400 + return 200 + + +def _build_openai_json_response(payload: Dict[str, Any]) -> JSONResponse: + return JSONResponse(content=payload, status_code=_get_error_status_code(payload)) + + +def _build_gemini_error_payload(status_code: int, message: str) -> Dict[str, Any]: + return { + "error": { + "code": status_code, + "message": message, + "status": GEMINI_STATUS_MAP.get(status_code, "UNKNOWN"), + } + } + + +def _build_gemini_error_response_from_handler(payload: Dict[str, Any]) -> JSONResponse: + error = payload.get("error", {}) + status_code = _get_error_status_code(payload) + message = error.get("message", "Generation failed") + return JSONResponse( + status_code=status_code, + content=_build_gemini_error_payload(status_code, message), + ) + + +def _extract_openai_message_content(payload: Dict[str, Any]) -> str: + choices = payload.get("choices", []) + if not choices: + return payload.get("result", "") + + message = choices[0].get("message", {}) + content = message.get("content", "") + return content if isinstance(content, str) else "" + + +async def _build_image_parts_from_uri(uri: str) -> List[Dict[str, Any]]: + if uri.startswith("data:image"): + mime_type, _ = _decode_data_url(uri) + match = DATA_URL_RE.match(uri) + if match: + return [{"inlineData": {"mimeType": mime_type, "data": match.group("data")}}] + + image_bytes = await retrieve_image_data(uri) + if image_bytes: + mime_type = _detect_image_mime_type( + image_bytes, + fallback=_guess_mime_type(uri, "image/png"), + ) + return [ + { + "inlineData": { + "mimeType": mime_type, + "data": base64.b64encode(image_bytes).decode("ascii"), + } + } + ] + + return [ + { + "fileData": { + "mimeType": _guess_mime_type(uri, "image/png"), + "fileUri": uri, + } + } + ] + + +def _build_video_parts_from_uri(uri: str) -> List[Dict[str, Any]]: + return [ + { + "fileData": { + "mimeType": _guess_mime_type(uri, "video/mp4"), + "fileUri": uri, + } + } + ] + + +async def _build_gemini_parts_from_output(output: str) -> List[Dict[str, Any]]: + if not output: + return [] + + image_matches = MARKDOWN_IMAGE_RE.findall(output) + if image_matches: + parts: List[Dict[str, Any]] = [] + for uri in image_matches: + parts.extend(await _build_image_parts_from_uri(uri)) + return parts + + video_matches = HTML_VIDEO_RE.findall(output) + if video_matches: + parts: List[Dict[str, Any]] = [] + for uri in 
video_matches: + parts.extend(_build_video_parts_from_uri(uri)) + return parts + + return [{"text": output}] + + +async def _build_gemini_success_payload( + payload: Dict[str, Any], + response_model: str, +) -> Dict[str, Any]: + output = _extract_openai_message_content(payload) + return { + "candidates": [ + { + "content": { + "role": "model", + "parts": await _build_gemini_parts_from_output(output), + }, + "finishReason": "STOP", + "index": 0, + } + ], + "modelVersion": response_model, + } + + +def _normalize_finish_reason(reason: Optional[str]) -> Optional[str]: + if reason is None: + return None + mapping = { + "stop": "STOP", + "length": "MAX_TOKENS", + "content_filter": "SAFETY", + } + return mapping.get(reason, "STOP") + + +async def _convert_openai_stream_chunk_to_gemini_event( + payload: Dict[str, Any], + response_model: str, +) -> Optional[str]: + choices = payload.get("choices", []) + if not choices: + return None + + choice = choices[0] + delta = choice.get("delta", {}) + text = delta.get("reasoning_content") or delta.get("content") or "" + finish_reason = _normalize_finish_reason(choice.get("finish_reason")) + + candidate: Dict[str, Any] = {"index": choice.get("index", 0)} + if text: + candidate["content"] = { + "role": "model", + "parts": await _build_gemini_parts_from_output(text), + } + if finish_reason: + candidate["finishReason"] = finish_reason + + if len(candidate) == 1: + return None + + chunk = { + "candidates": [candidate], + "modelVersion": response_model, + } + return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" + + +async def _iterate_openai_stream( + normalized: NormalizedGenerationRequest, +): + handler = _ensure_generation_handler() + async for chunk in handler.handle_generation( + model=normalized.model, + prompt=normalized.prompt, + images=normalized.images if normalized.images else None, + stream=True, + ): + if chunk.startswith("data: "): + yield chunk + continue + + payload = _parse_handler_result(chunk) + yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + + yield "data: [DONE]\n\n" + + +async def _iterate_gemini_stream( + normalized: NormalizedGenerationRequest, + response_model: str, +): + handler = _ensure_generation_handler() + async for chunk in handler.handle_generation( + model=normalized.model, + prompt=normalized.prompt, + images=normalized.images if normalized.images else None, + stream=True, + ): + if chunk.startswith("data: "): + payload_text = chunk[6:].strip() + if payload_text == "[DONE]": + continue + payload = _parse_handler_result(payload_text) + if "error" in payload: + yield ( + f"data: {json.dumps(_build_gemini_error_payload(_get_error_status_code(payload), payload['error'].get('message', 'Generation failed')), ensure_ascii=False)}\n\n" + ) + return + + event = await _convert_openai_stream_chunk_to_gemini_event( + payload, + response_model, + ) + if event: + yield event + continue + + payload = _parse_handler_result(chunk) + if "error" in payload: + yield ( + f"data: {json.dumps(_build_gemini_error_payload(_get_error_status_code(payload), payload['error'].get('message', 'Generation failed')), ensure_ascii=False)}\n\n" + ) + return + + event = await _convert_openai_stream_chunk_to_gemini_event( + payload, + response_model, + ) + if event: + yield event + + @router.get("/v1/models") -async def list_models(api_key: str = Depends(verify_api_key_header)): - """List available models""" +async def list_models(api_key: str = Depends(verify_api_key_flexible)): + """List available models.""" models = [] for model_id, config in 
MODEL_CONFIG.items(): @@ -90,8 +607,8 @@ async def list_models(api_key: str = Depends(verify_api_key_header)): @router.get("/v1/models/aliases") -async def list_model_aliases(api_key: str = Depends(verify_api_key_header)): - """List simplified model name aliases that can be used with generationConfig""" +async def list_model_aliases(api_key: str = Depends(verify_api_key_flexible)): + """List simplified model aliases for generationConfig-based resolution.""" aliases = get_base_model_aliases() alias_models = [] for alias_id, description in aliases.items(): @@ -109,141 +626,18 @@ async def list_model_aliases(api_key: str = Depends(verify_api_key_header)): @router.post("/v1/chat/completions") async def create_chat_completion( - request: ChatCompletionRequest, api_key: str = Depends(verify_api_key_header) + request: ChatCompletionRequest, + api_key: str = Depends(verify_api_key_flexible), ): - """Create chat completion (unified endpoint for image and video generation)""" + """OpenAI-compatible unified generation endpoint.""" try: - # ── 模型名解析:基于 generationConfig 参数转换简化模型名 ── - original_model = request.model - request.model = resolve_model_name( - model=request.model, request=request, model_config=MODEL_CONFIG - ) - if request.model != original_model: - debug_logger.log_info( - f"[ROUTE] 模型名已转换: {original_model} → {request.model}" - ) - - # Extract prompt from messages - if not request.messages: - raise HTTPException(status_code=400, detail="Messages cannot be empty") - - last_message = request.messages[-1] - content = last_message.content - - # Handle both string and array format (OpenAI multimodal) - prompt = "" - images: List[bytes] = [] - - if isinstance(content, str): - # Simple text format - prompt = content - elif isinstance(content, list): - # Multimodal format - for item in content: - if item.get("type") == "text": - prompt = item.get("text", "") - elif item.get("type") == "image_url": - # Extract image from URL or base64 - image_url = item.get("image_url", {}).get("url", "") - if image_url.startswith("data:image"): - # Parse base64 - match = re.search(r"base64,(.+)", image_url) - if match: - image_base64 = match.group(1) - image_bytes = base64.b64decode(image_base64) - images.append(image_bytes) - elif image_url.startswith("http://") or image_url.startswith( - "https://" - ): - # Download remote image URL - debug_logger.log_info(f"[IMAGE_URL] 下载远程图片: {image_url}") - try: - downloaded_bytes = await retrieve_image_data(image_url) - if downloaded_bytes and len(downloaded_bytes) > 0: - images.append(downloaded_bytes) - debug_logger.log_info( - f"[IMAGE_URL] ✅ 远程图片下载成功: {len(downloaded_bytes)} 字节" - ) - else: - debug_logger.log_warning( - f"[IMAGE_URL] ⚠️ 远程图片下载失败或为空: {image_url}" - ) - except Exception as e: - debug_logger.log_error( - f"[IMAGE_URL] ❌ 远程图片下载异常: {str(e)}" - ) - - # Fallback to deprecated image parameter - if request.image and not images: - if request.image.startswith("data:image"): - match = re.search(r"base64,(.+)", request.image) - if match: - image_base64 = match.group(1) - image_bytes = base64.b64decode(image_base64) - images.append(image_bytes) - - # 自动参考图:仅对图片模型生效 - model_config = MODEL_CONFIG.get(request.model) - - if ( - model_config - and model_config["type"] == "image" - and len(request.messages) > 1 - ): - debug_logger.log_info( - f"[CONTEXT] 开始查找历史参考图,消息数量: {len(request.messages)}" - ) - - # 查找上一次 assistant 回复的图片 - for msg in reversed(request.messages[:-1]): - if msg.role == "assistant" and isinstance(msg.content, str): - # 匹配 Markdown 图片格式: ![...](http...) 
- matches = re.findall(r"!\[.*?\]\((.*?)\)", msg.content) - if matches: - last_image_url = matches[-1] - - if last_image_url.startswith("http"): - try: - downloaded_bytes = await retrieve_image_data( - last_image_url - ) - if downloaded_bytes and len(downloaded_bytes) > 0: - # 将历史图片插入到最前面 - images.insert(0, downloaded_bytes) - debug_logger.log_info( - f"[CONTEXT] ✅ 添加历史参考图: {last_image_url}" - ) - break - else: - debug_logger.log_warning( - f"[CONTEXT] 图片下载失败或为空,尝试下一个: {last_image_url}" - ) - except Exception as e: - debug_logger.log_error( - f"[CONTEXT] 处理参考图时出错: {str(e)}" - ) - # 继续尝试下一个图片 - - if not prompt: + normalized = await _normalize_openai_request(request) + if not normalized.prompt: raise HTTPException(status_code=400, detail="Prompt cannot be empty") - # Call generation handler if request.stream: - # Streaming response - async def generate(): - async for chunk in generation_handler.handle_generation( - model=request.model, - prompt=prompt, - images=images if images else None, - stream=True, - ): - yield chunk - - # Send [DONE] signal - yield "data: [DONE]\n\n" - return StreamingResponse( - generate(), + _iterate_openai_stream(normalized), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", @@ -251,32 +645,91 @@ async def generate(): "X-Accel-Buffering": "no", }, ) - else: - # Non-streaming response - result = None - async for chunk in generation_handler.handle_generation( - model=request.model, - prompt=prompt, - images=images if images else None, - stream=False, - ): - result = chunk - - if result: - # Parse the result JSON string - try: - result_json = json.loads(result) - return JSONResponse(content=result_json) - except json.JSONDecodeError: - # If not JSON, return as-is - return JSONResponse(content={"result": result}) - else: - raise HTTPException( - status_code=500, - detail="Generation failed: No response from handler", - ) + + payload = _parse_handler_result( + await _collect_non_stream_result( + normalized.model, + normalized.prompt, + normalized.images, + ) + ) + return _build_openai_json_response(payload) except HTTPException: raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + +@router.post("/v1beta/models/{model}:generateContent") +@router.post("/models/{model}:generateContent") +async def generate_content( + model: str, + request: GeminiGenerateContentRequest, + api_key: str = Depends(verify_api_key_flexible), +): + """Gemini official generateContent endpoint.""" + try: + normalized = await _normalize_gemini_request(model, request) + if not normalized.prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + payload = _parse_handler_result( + await _collect_non_stream_result( + normalized.model, + normalized.prompt, + normalized.images, + ) + ) + if "error" in payload: + return _build_gemini_error_response_from_handler(payload) + + return JSONResponse( + content=await _build_gemini_success_payload(payload, model) + ) + + except HTTPException as exc: + return JSONResponse( + status_code=exc.status_code, + content=_build_gemini_error_payload(exc.status_code, str(exc.detail)), + ) + except Exception as exc: + return JSONResponse( + status_code=500, + content=_build_gemini_error_payload(500, str(exc)), + ) + + +@router.post("/v1beta/models/{model}:streamGenerateContent") +@router.post("/models/{model}:streamGenerateContent") +async def stream_generate_content( + model: str, + request: 
GeminiGenerateContentRequest, + alt: Optional[str] = Query(None), + api_key: str = Depends(verify_api_key_flexible), +): + """Gemini official streamGenerateContent endpoint.""" + try: + normalized = await _normalize_gemini_request(model, request) + if not normalized.prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + return StreamingResponse( + _iterate_gemini_stream(normalized, model), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + except HTTPException as exc: + return JSONResponse( + status_code=exc.status_code, + content=_build_gemini_error_payload(exc.status_code, str(exc.detail)), + ) + except Exception as exc: + return JSONResponse( + status_code=500, + content=_build_gemini_error_payload(500, str(exc)), + ) diff --git a/src/core/auth.py b/src/core/auth.py index 0568573..0693e7a 100644 --- a/src/core/auth.py +++ b/src/core/auth.py @@ -1,11 +1,13 @@ """Authentication module""" + import bcrypt from typing import Optional -from fastapi import HTTPException, Security -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import Header, HTTPException, Query, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from .config import config security = HTTPBearer() +optional_security = HTTPBearer(auto_error=False) class AuthManager: """Authentication manager""" @@ -37,3 +39,24 @@ async def verify_api_key_header(credentials: HTTPAuthorizationCredentials = Secu if not AuthManager.verify_api_key(api_key): raise HTTPException(status_code=401, detail="Invalid API key") return api_key + + +async def verify_api_key_flexible( + credentials: Optional[HTTPAuthorizationCredentials] = Security(optional_security), + x_goog_api_key: Optional[str] = Header(None, alias="x-goog-api-key"), + key: Optional[str] = Query(None), +) -> str: + """Verify API key from Authorization header, x-goog-api-key header, or key query param.""" + api_key = None + + if credentials is not None: + api_key = credentials.credentials + elif x_goog_api_key: + api_key = x_goog_api_key + elif key: + api_key = key + + if not api_key or not AuthManager.verify_api_key(api_key): + raise HTTPException(status_code=401, detail="Invalid API key") + + return api_key diff --git a/src/core/model_resolver.py b/src/core/model_resolver.py index 1c0b925..7d4b753 100644 --- a/src/core/model_resolver.py +++ b/src/core/model_resolver.py @@ -84,6 +84,8 @@ # imageSize 归一化映射 IMAGE_SIZE_MAP = { + "1k": "1k", + "1K": "1k", "2k": "2k", "2K": "2k", "4k": "4k", @@ -185,6 +187,10 @@ def _extract_generation_params(request) -> Tuple[Optional[str], Optional[str]]: if gen_config is None and hasattr(request, "__pydantic_extra__"): extra = request.__pydantic_extra__ or {} gen_config_raw = extra.get("generationConfig") + if not isinstance(gen_config_raw, dict): + extra_body = extra.get("extra_body") or extra.get("extraBody") + if isinstance(extra_body, dict): + gen_config_raw = extra_body.get("generationConfig") if isinstance(gen_config_raw, dict): image_config_raw = gen_config_raw.get("imageConfig", {}) if isinstance(image_config_raw, dict): @@ -229,10 +235,6 @@ def resolve_model_name( Returns: 解析后的内部模型名 """ - # 如果已经是有效的 MODEL_CONFIG key,直接返回 - if model_config and model in model_config: - return model - # ────── 图片模型解析 ────── if model in IMAGE_BASE_MODELS: base = IMAGE_BASE_MODELS[model] @@ -257,7 +259,7 @@ def resolve_model_name( resolved = f"{base}-{aspect_ratio}" # 检查支持的 imageSize 
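+        # 约定:"1K"/"1k" 经 IMAGE_SIZE_MAP 归一化为 "1k",视作默认尺寸,不追加尺寸后缀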
- if image_size: + if image_size and image_size != "1k": supported_sizes = MODEL_SUPPORTED_SIZES.get(base, []) if image_size in supported_sizes: resolved = f"{resolved}-{image_size}" @@ -306,6 +308,10 @@ def resolve_model_name( ) return model + # 如果已经是有效的 MODEL_CONFIG key,直接返回 + if model_config and model in model_config: + return model + # 未知模型名,原样返回(由下游 MODEL_CONFIG 校验报错) return model diff --git a/src/core/models.py b/src/core/models.py index 8b39230..08b478e 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -1,7 +1,7 @@ """Data models for Flow2API""" -from pydantic import BaseModel -from typing import Optional, List, Union, Any +from pydantic import BaseModel, ConfigDict +from typing import Optional, List, Union, Any, Literal from datetime import datetime @@ -219,15 +219,55 @@ class GenerationConfigParam(BaseModel): responseModalities: Optional[List[str]] = None # ["IMAGE", "TEXT"] imageConfig: Optional[ImageConfig] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") + + +class GeminiInlineData(BaseModel): + """Gemini inline binary data.""" + + mimeType: str + data: str + + +class GeminiFileData(BaseModel): + """Gemini file reference.""" + + fileUri: str + mimeType: Optional[str] = None + + +class GeminiPart(BaseModel): + """Gemini content part.""" + + text: Optional[str] = None + inlineData: Optional[GeminiInlineData] = None + fileData: Optional[GeminiFileData] = None + + model_config = ConfigDict(extra="allow") + + +class GeminiContent(BaseModel): + """Gemini content block.""" + + role: Optional[Literal["user", "model"]] = None + parts: List[GeminiPart] + + +class GeminiGenerateContentRequest(BaseModel): + """Gemini official generateContent request.""" + + contents: List[GeminiContent] + generationConfig: Optional[GenerationConfigParam] = None + systemInstruction: Optional[GeminiContent] = None + + model_config = ConfigDict(extra="allow") class ChatCompletionRequest(BaseModel): """Chat completion request (OpenAI compatible + Gemini extension)""" model: str - messages: List[ChatMessage] + messages: Optional[List[ChatMessage]] = None stream: bool = False temperature: Optional[float] = None max_tokens: Optional[int] = None @@ -238,5 +278,4 @@ class ChatCompletionRequest(BaseModel): generationConfig: Optional[GenerationConfigParam] = None contents: Optional[List[Any]] = None # Gemini native contents - class Config: - extra = "allow" # Allow extra fields like extra_body passthrough + model_config = ConfigDict(extra="allow") # Allow extra fields like extra_body passthrough diff --git a/src/services/flow_client.py b/src/services/flow_client.py index ce66214..3aebe73 100644 --- a/src/services/flow_client.py +++ b/src/services/flow_client.py @@ -6,6 +6,7 @@ import uuid import random import base64 +import ssl from typing import Dict, Any, Optional, List, Union from urllib.parse import quote import urllib.error @@ -308,8 +309,109 @@ async def _make_request( debug_logger.log_error(f"[API FAILED] Request Body: {json_data}") debug_logger.log_error(f"[API FAILED] Exception: {error_msg}") + if self._should_fallback_to_urllib(error_msg): + debug_logger.log_warning( + f"[HTTP FALLBACK] curl_cffi 请求失败,回退 urllib: {method.upper()} {url}" + ) + try: + return await asyncio.to_thread( + self._sync_json_request_via_urllib, + method.upper(), + url, + headers, + json_data, + proxy_url, + request_timeout, + ) + except Exception as fallback_error: + debug_logger.log_error( + f"[HTTP FALLBACK] urllib 回退也失败: {fallback_error}" + ) + raise Exception( + f"Flow API request 
failed: curl={error_msg}; urllib={fallback_error}" + ) + raise Exception(f"Flow API request failed: {error_msg}") + def _should_fallback_to_urllib(self, error_message: str) -> bool: + """判断是否应从 curl_cffi 回退到 urllib。""" + error_lower = (error_message or "").lower() + return any( + keyword in error_lower + for keyword in [ + "curl: (6)", + "curl: (7)", + "curl: (28)", + "curl: (35)", + "curl: (52)", + "curl: (56)", + "connection timed out", + "could not connect", + "failed to connect", + "ssl connect error", + "tls connect error", + "network is unreachable", + ] + ) + + def _sync_json_request_via_urllib( + self, + method: str, + url: str, + headers: Optional[Dict[str, Any]], + json_data: Optional[Dict[str, Any]], + proxy_url: Optional[str], + timeout: int, + ) -> Dict[str, Any]: + """使用 urllib 执行 JSON 请求,作为 curl_cffi 的网络回退。""" + request_headers = dict(headers or {}) + request_headers.setdefault("Accept", "application/json") + + data = None + if method.upper() != "GET" and json_data is not None: + data = json.dumps(json_data, ensure_ascii=False).encode("utf-8") + request_headers["Content-Type"] = "application/json" + + handlers = [urllib.request.HTTPSHandler(context=ssl.create_default_context())] + if proxy_url: + handlers.append( + urllib.request.ProxyHandler( + {"http": proxy_url, "https": proxy_url} + ) + ) + + opener = urllib.request.build_opener(*handlers) + request = urllib.request.Request( + url=url, + data=data, + headers=request_headers, + method=method.upper(), + ) + + try: + with opener.open( + request, + timeout=timeout, + ) as response: + payload = response.read() + status_code = int(response.getcode() or 0) + except urllib.error.HTTPError as exc: + payload = exc.read() if hasattr(exc, "read") else b"" + status_code = int(getattr(exc, "code", 500) or 500) + body_text = payload.decode("utf-8", errors="replace") + raise Exception(f"HTTP Error {status_code}: {body_text[:200]}") from exc + except Exception as exc: + raise Exception(str(exc)) from exc + + body_text = payload.decode("utf-8", errors="replace") + if status_code >= 400: + raise Exception(f"HTTP Error {status_code}: {body_text[:200]}") + + try: + return json.loads(body_text) if body_text else {} + except Exception as exc: + raise Exception(f"Invalid JSON response: {body_text[:200]}") from exc + def _is_timeout_error(self, error: Exception) -> bool: """判断是否为网络超时,便于快速失败重试。""" error_lower = str(error).lower() @@ -343,6 +445,10 @@ def _is_retryable_network_error(self, error_str: str) -> bool: "remote host closed connection", ]) + def _get_control_plane_timeout(self) -> int: + """控制轻量控制面请求的超时,避免认证/项目接口长时间挂起。""" + return max(5, min(int(self.timeout or 0) or 120, 10)) + async def _acquire_image_launch_gate( self, token_id: Optional[int], @@ -488,7 +594,8 @@ async def st_to_at(self, st: str) -> dict: method="GET", url=url, use_st=True, - st_token=st + st_token=st, + timeout=self._get_control_plane_timeout(), ) return result @@ -511,18 +618,45 @@ async def create_project(self, st: str, title: str) -> str: "toolName": "PINHOLE" } } + max_retries = max(2, min(4, int(getattr(config, "flow_max_retries", 3) or 3))) + request_timeout = max(self._get_control_plane_timeout(), min(self.timeout, 15)) + last_error: Optional[Exception] = None - result = await self._make_request( - method="POST", - url=url, - json_data=json_data, - use_st=True, - st_token=st - ) + for retry_attempt in range(max_retries): + try: + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_st=True, + st_token=st, + 
timeout=request_timeout, + ) + project_result = ( + result.get("result", {}) + .get("data", {}) + .get("json", {}) + .get("result", {}) + ) + project_id = project_result.get("projectId") + if not project_id: + raise Exception("Invalid project.createProject response: missing projectId") + return project_id + except Exception as e: + last_error = e + retry_reason = "网络超时" if self._is_timeout_error(e) else self._get_retry_reason(str(e)) + if retry_reason and retry_attempt < max_retries - 1: + debug_logger.log_warning( + f"[PROJECT] 创建项目失败,准备重试 ({retry_attempt + 2}/{max_retries}) " + f"title={title!r}, reason={retry_reason}: {e}" + ) + await asyncio.sleep(1) + continue + raise - # 解析返回的project_id - project_id = result["result"]["data"]["json"]["result"]["projectId"] - return project_id + if last_error is not None: + raise last_error + raise RuntimeError("创建项目失败") async def delete_project(self, st: str, project_id: str): """删除项目 @@ -543,7 +677,8 @@ async def delete_project(self, st: str, project_id: str): url=url, json_data=json_data, use_st=True, - st_token=st + st_token=st, + timeout=self._get_control_plane_timeout(), ) # ========== 余额查询 (使用AT) ========== @@ -565,7 +700,8 @@ async def get_credits(self, at: str) -> dict: method="GET", url=url, use_at=True, - at_token=at + at_token=at, + timeout=self._get_control_plane_timeout(), ) return result diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 0c22c69..7f27cef 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -786,7 +786,7 @@ async def handle_generation( if model not in MODEL_CONFIG: error_msg = f"不支持的模型: {model}" debug_logger.log_error(error_msg) - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=400) return model_config = MODEL_CONFIG[model] @@ -800,26 +800,6 @@ async def handle_generation( } debug_logger.log_info(f"[GENERATION] 开始生成 - 模型: {model}, 类型: {generation_type}, Prompt: {prompt[:50]}...") - # 非流式模式: 只检查可用性 - if not stream: - is_image = (generation_type == "image") - is_video = (generation_type == "video") - available = await self.check_token_availability(is_image, is_video) - - if available: - if is_image: - message = "所有Token可用于图片生成。请启用流式模式使用生成功能。" - else: - message = "所有Token可用于视频生成。请启用流式模式使用生成功能。" - else: - if is_image: - message = "没有可用的Token进行图片生成" - else: - message = "没有可用的Token进行视频生成" - - yield self._create_completion_response(message, is_availability_check=True) - return - # 向用户展示开始信息 if stream: yield self._create_stream_chunk( @@ -875,7 +855,7 @@ async def handle_generation( ) if stream: yield self._create_stream_chunk(f"❌ {error_msg}\n") - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=503) return debug_logger.log_info(f"[GENERATION] 已选择Token: {token.id} ({token.email})") @@ -907,7 +887,7 @@ async def handle_generation( debug_logger.log_error(f"[GENERATION] {error_msg}") if stream: yield self._create_stream_chunk(f"❌ {error_msg}\n") - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=503) return # 4. 
确保Project存在 @@ -919,7 +899,7 @@ async def handle_generation( debug_logger.log_error(f"[GENERATION] {error_msg}") if stream: yield self._create_stream_chunk(f"❌ {error_msg}\n") - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=403) return ensure_project_started_at = time.time() @@ -985,7 +965,7 @@ async def handle_generation( if not generation_result.get("error_emitted"): if stream: yield self._create_stream_chunk(f"❌ {error_msg}\n") - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=500) return is_video = (generation_type == "video") @@ -1092,7 +1072,7 @@ async def handle_generation( ) if stream: yield self._create_stream_chunk(f"❌ {error_msg}\n") - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=500) finally: if pending_token_state.get("active") and token and self.load_balancer: await self.load_balancer.release_pending( @@ -1194,7 +1174,7 @@ async def _handle_image_generation( media = result.get("media", []) if not media: self._mark_generation_failed(generation_result, "\u751f\u6210\u7ed3\u679c\u4e3a\u7a7a") - yield self._create_error_response("生成结果为空") + yield self._create_error_response("生成结果为空", status_code=502) return image_url = media[0]["image"]["generatedImage"]["fifeUrl"] @@ -1461,7 +1441,7 @@ async def _handle_video_generation( if stream: yield self._create_stream_chunk(f"{error_msg}\n") self._mark_generation_failed(generation_result, error_msg) - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=400) return # R2V: 多图生成 - 当前上游协议最多 3 张参考图 @@ -1471,7 +1451,7 @@ async def _handle_video_generation( if stream: yield self._create_stream_chunk(f"{error_msg}\n") self._mark_generation_failed(generation_result, error_msg) - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=400) return # ========== 上传图片 ========== @@ -1591,7 +1571,7 @@ async def _handle_video_generation( operations = result.get("operations", []) if not operations: self._mark_generation_failed(generation_result, "\u751f\u6210\u4efb\u52a1\u521b\u5efa\u5931\u8d25") - yield self._create_error_response("生成任务创建失败") + yield self._create_error_response("生成任务创建失败", status_code=502) return operation = operations[0] @@ -1691,7 +1671,7 @@ async def _poll_video_result( error_msg = "视频生成失败: 视频URL为空" await self._fail_video_task(checked_operations, error_msg) self._mark_generation_failed(generation_result, error_msg) - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=502) return # ========== 视频放大处理 ========== @@ -1802,7 +1782,7 @@ async def _poll_video_result( self._mark_generation_failed(generation_result, friendly_error) if stream: yield self._create_stream_chunk(f"❌ {friendly_error}\n") - yield self._create_error_response(friendly_error) + yield self._create_error_response(friendly_error, status_code=502) return elif status.startswith("MEDIA_GENERATION_STATUS_ERROR"): @@ -1810,7 +1790,7 @@ async def _poll_video_result( error_msg = f"视频生成失败: {status}" await self._fail_video_task(checked_operations, error_msg) self._mark_generation_failed(generation_result, error_msg) - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=502) return except Exception as e: @@ -1823,7 +1803,7 @@ async def _poll_video_result( self._mark_generation_failed(generation_result, error_msg) 
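+                # 说明:轮询阶段抛出的异常统一按上游故障处理,返回 502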
if stream: yield self._create_stream_chunk(f"❌ {error_msg}\n") - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=502) return continue @@ -1834,7 +1814,7 @@ async def _poll_video_result( error_msg = f"视频生成超时 (已轮询 {max_attempts} 次)" await self._fail_video_task(operations, error_msg) self._mark_generation_failed(generation_result, error_msg) - yield self._create_error_response(error_msg) + yield self._create_error_response(error_msg, status_code=504) # ========== 响应格式化 ========== @@ -1906,15 +1886,16 @@ def _create_completion_response(self, content: str, media_type: str = "image", i return json.dumps(response, ensure_ascii=False) - def _create_error_response(self, error_message: str) -> str: + def _create_error_response(self, error_message: str, status_code: int = 500) -> str: """创建错误响应""" import json error = { "error": { "message": error_message, - "type": "invalid_request_error", - "code": "generation_failed" + "type": "server_error" if status_code >= 500 else "invalid_request_error", + "code": "generation_failed", + "status_code": status_code, } } diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c29a04d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,60 @@ +from pathlib import Path +import sys + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from src.api import routes +from src.core.auth import verify_api_key_flexible + + +class FakeGenerationHandler: + def __init__(self): + self.calls = [] + self.file_cache = None + self.non_stream_chunks = [] + self.stream_chunks = [] + + async def handle_generation(self, model, prompt, images=None, stream=False): + self.calls.append( + { + "model": model, + "prompt": prompt, + "images": images, + "stream": stream, + } + ) + chunks = self.stream_chunks if stream else self.non_stream_chunks + for chunk in chunks: + yield chunk + + +@pytest.fixture +def fake_handler(): + return FakeGenerationHandler() + + +@pytest.fixture +def fastapi_app(fake_handler): + app = FastAPI() + app.include_router(routes.router) + + async def fake_auth(): + return "test-api-key" + + app.dependency_overrides[verify_api_key_flexible] = fake_auth + routes.set_generation_handler(fake_handler) + yield app + app.dependency_overrides.clear() + routes.set_generation_handler(None) + + +@pytest.fixture +def client(fastapi_app): + with TestClient(fastapi_app) as test_client: + yield test_client diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py new file mode 100644 index 0000000..ec67b8a --- /dev/null +++ b/tests/test_api_routes.py @@ -0,0 +1,84 @@ +import asyncio +import json + +from src.api import routes +from src.core.auth import AuthManager, verify_api_key_flexible + + +def build_openai_completion(content: str) -> str: + return json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1, + "model": "flow2api", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + } + ], + } + ) + + +def test_openai_route_resolves_alias_and_returns_non_stream_result(client, fake_handler): + fake_handler.non_stream_chunks = [build_openai_completion("![Generated Image](https://example.com/out.png)")] + + response = client.post( + "/v1/chat/completions", + json={ + "model": "gemini-3.0-pro-image", + "messages": [{"role": "user", "content": "draw a 
sunset"}], + "generationConfig": { + "imageConfig": { + "aspectRatio": "16:9", + "imageSize": "2K", + } + }, + }, + ) + + assert response.status_code == 200 + assert fake_handler.calls[0]["model"] == "gemini-3.0-pro-image-landscape-2k" + assert response.json()["choices"][0]["message"]["content"].startswith("![Generated Image]") + + +def test_openai_route_returns_handler_error_status(client, fake_handler): + fake_handler.non_stream_chunks = [ + json.dumps( + { + "error": { + "message": "没有可用的Token进行图片生成", + "status_code": 503, + } + } + ) + ] + + response = client.post( + "/v1/chat/completions", + json={ + "model": "gemini-3.0-pro-image", + "messages": [{"role": "user", "content": "draw a tree"}], + }, + ) + + assert response.status_code == 503 + assert response.json()["error"]["message"] == "没有可用的Token进行图片生成" + + +def test_flexible_auth_accepts_x_goog_api_key(monkeypatch): + monkeypatch.setattr(AuthManager, "verify_api_key", staticmethod(lambda api_key: api_key == "secret")) + + assert asyncio.run( + verify_api_key_flexible( + credentials=None, + x_goog_api_key="secret", + key=None, + ) + ) == "secret" diff --git a/tests/test_flow_client.py b/tests/test_flow_client.py new file mode 100644 index 0000000..c7f4139 --- /dev/null +++ b/tests/test_flow_client.py @@ -0,0 +1,56 @@ +import asyncio + +import pytest + +from src.services import flow_client as flow_client_module +from src.services.flow_client import FlowClient + + +def test_create_project_retries_timeout_then_succeeds(monkeypatch): + client = FlowClient(proxy_manager=None) + attempts = [] + sleep_calls = [] + + async def fake_make_request(**kwargs): + attempts.append(kwargs["timeout"]) + if len(attempts) < 3: + raise Exception("Flow API request failed: curl: (28) Connection timed out after 5013 milliseconds") + return { + "result": { + "data": { + "json": { + "result": { + "projectId": "project-123", + } + } + } + } + } + + async def fake_sleep(seconds): + sleep_calls.append(seconds) + + monkeypatch.setattr(client, "_make_request", fake_make_request) + monkeypatch.setattr(flow_client_module.asyncio, "sleep", fake_sleep) + + project_id = asyncio.run(client.create_project("st-token", "Retry Test")) + + assert project_id == "project-123" + assert attempts == [15, 15, 15] + assert sleep_calls == [1, 1] + + +def test_create_project_invalid_response_fails_fast(monkeypatch): + client = FlowClient(proxy_manager=None) + attempts = [] + + async def fake_make_request(**kwargs): + attempts.append(kwargs["timeout"]) + return {"result": {"data": {"json": {"result": {}}}}} + + monkeypatch.setattr(client, "_make_request", fake_make_request) + + with pytest.raises(Exception, match="missing projectId"): + asyncio.run(client.create_project("st-token", "Invalid Response")) + + assert attempts == [15] diff --git a/tests/test_flow_client_control_plane.py b/tests/test_flow_client_control_plane.py new file mode 100644 index 0000000..797ccf9 --- /dev/null +++ b/tests/test_flow_client_control_plane.py @@ -0,0 +1,48 @@ +import asyncio + +from src.services.flow_client import FlowClient + + +def test_control_plane_timeout_is_capped(): + client = FlowClient(None) + + client.timeout = 120 + assert client._get_control_plane_timeout() == 10 + + client.timeout = 8 + assert client._get_control_plane_timeout() == 8 + + client.timeout = 3 + assert client._get_control_plane_timeout() == 5 + + +def test_control_plane_calls_use_short_timeouts(monkeypatch): + client = FlowClient(None) + client.timeout = 120 + calls = [] + + async def fake_make_request(**kwargs): + 
calls.append({ + "url": kwargs["url"], + "timeout": kwargs.get("timeout"), + }) + url = kwargs["url"] + if url.endswith("/auth/session"): + return {"access_token": "at", "user": {"email": "tester@example.com"}} + if url.endswith("/trpc/project.createProject"): + return {"result": {"data": {"json": {"result": {"projectId": "project-123"}}}}} + if url.endswith("/credits"): + return {"credits": 1000, "userPaygateTier": "PAYGATE_TIER_ONE"} + return {} + + monkeypatch.setattr(client, "_make_request", fake_make_request) + + async def run(): + await client.st_to_at("st") + await client.create_project("st", "demo") + await client.delete_project("st", "project-123") + await client.get_credits("at") + + asyncio.run(run()) + + assert [call["timeout"] for call in calls] == [10, 15, 10, 10] diff --git a/tests/test_gemini_generate_content.py b/tests/test_gemini_generate_content.py new file mode 100644 index 0000000..467ae51 --- /dev/null +++ b/tests/test_gemini_generate_content.py @@ -0,0 +1,149 @@ +import base64 +import json + +from src.api import routes + + +def build_openai_completion(content: str) -> str: + return json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1, + "model": "flow2api", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + } + ], + } + ) + + +def test_generate_content_returns_gemini_response(client, fake_handler, monkeypatch): + fake_handler.non_stream_chunks = [ + build_openai_completion("![Generated Image](https://example.com/generated.png)") + ] + + async def fake_retrieve_image_data(url: str): + return b"\x89PNG\r\n\x1a\nfake" + + monkeypatch.setattr(routes, "retrieve_image_data", fake_retrieve_image_data) + + response = client.post( + "/v1beta/models/gemini-3.0-pro-image:generateContent", + json={ + "contents": [ + { + "role": "user", + "parts": [{"text": "draw a mountain"}], + } + ], + "generationConfig": { + "imageConfig": { + "aspectRatio": "16:9", + "imageSize": "2K", + } + }, + }, + ) + + assert response.status_code == 200 + assert fake_handler.calls[0]["model"] == "gemini-3.0-pro-image-landscape-2k" + body = response.json() + assert body["modelVersion"] == "gemini-3.0-pro-image" + part = body["candidates"][0]["content"]["parts"][0]["inlineData"] + assert part["mimeType"] == "image/png" + assert base64.b64decode(part["data"]).startswith(b"\x89PNG") + + +def test_stream_generate_content_returns_sse_chunks(client, fake_handler, monkeypatch): + fake_handler.stream_chunks = [ + 'data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1,"model":"flow2api","choices":[{"index":0,"delta":{"reasoning_content":"starting generation"},"finish_reason":null}]}\n\n', + 'data: {"id":"chatcmpl-test","object":"chat.completion.chunk","created":1,"model":"flow2api","choices":[{"index":0,"delta":{"content":"![Generated Image](https://example.com/final.png)"},"finish_reason":"stop"}]}\n\n', + ] + + async def fake_retrieve_image_data(url: str): + return b"\x89PNG\r\n\x1a\nstream" + + monkeypatch.setattr(routes, "retrieve_image_data", fake_retrieve_image_data) + + response = client.post( + "/v1beta/models/gemini-3.0-pro-image:streamGenerateContent?alt=sse", + json={ + "contents": [ + { + "role": "user", + "parts": [{"text": "draw a city"}], + } + ] + }, + ) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + + data_lines = [ + line.removeprefix("data: ") + for line in response.text.splitlines() + if 
line.startswith("data: ") + ] + assert len(data_lines) == 2 + + first_chunk = json.loads(data_lines[0]) + assert first_chunk["modelVersion"] == "gemini-3.0-pro-image" + assert first_chunk["candidates"][0]["content"]["parts"][0]["text"] == "starting generation" + + second_chunk = json.loads(data_lines[1]) + assert second_chunk["modelVersion"] == "gemini-3.0-pro-image" + image_part = second_chunk["candidates"][0]["content"]["parts"][0]["inlineData"] + assert image_part["mimeType"] == "image/png" + assert second_chunk["candidates"][0]["finishReason"] == "STOP" + + +def test_models_generate_content_supports_system_instruction_and_file_data(client, fake_handler): + fake_handler.non_stream_chunks = [ + build_openai_completion("![Generated Image](https://example.com/generated-square.png)") + ] + + reference_image = base64.b64encode(b"\x89PNG\r\n\x1a\nref").decode() + + response = client.post( + "/models/gemini-3.1-flash-image:generateContent", + json={ + "systemInstruction": { + "parts": [{"text": "answer in English"}], + }, + "contents": [ + { + "role": "user", + "parts": [ + {"text": "draw a cat"}, + { + "fileData": { + "fileUri": f"data:image/png;base64,{reference_image}", + "mimeType": "image/png", + } + }, + ], + } + ], + "generationConfig": { + "imageConfig": { + "aspectRatio": "1:1", + "imageSize": "1K", + } + }, + }, + ) + + assert response.status_code == 200 + assert fake_handler.calls[0]["model"] == "gemini-3.1-flash-image-square" + assert response.json()["modelVersion"] == "gemini-3.1-flash-image" + assert fake_handler.calls[0]["prompt"] == "answer in English\n\ndraw a cat" + assert len(fake_handler.calls[0]["images"]) == 1 diff --git a/tests/test_model_resolver.py b/tests/test_model_resolver.py new file mode 100644 index 0000000..d333ddc --- /dev/null +++ b/tests/test_model_resolver.py @@ -0,0 +1,98 @@ +import pytest + +from src.core.model_resolver import resolve_model_name +from src.core.models import ChatCompletionRequest, ChatMessage +from src.services.generation_handler import MODEL_CONFIG + + +def build_request(model: str, **kwargs) -> ChatCompletionRequest: + payload = { + "model": model, + "messages": [ChatMessage(role="user", content="draw a cat")], + } + payload.update(kwargs) + return ChatCompletionRequest(**payload) + + +def test_image_alias_resolves_with_official_generation_config(): + request = build_request( + "gemini-3.0-pro-image", + generationConfig={ + "imageConfig": { + "aspectRatio": "16:9", + "imageSize": "2K", + } + }, + ) + + assert ( + resolve_model_name(request.model, request, MODEL_CONFIG) + == "gemini-3.0-pro-image-landscape-2k" + ) + + +def test_image_alias_treats_1k_as_default_size(): + request = build_request( + "gemini-3.1-flash-image", + generationConfig={ + "imageConfig": { + "aspectRatio": "1:1", + "imageSize": "1K", + } + }, + ) + + assert ( + resolve_model_name(request.model, request, MODEL_CONFIG) + == "gemini-3.1-flash-image-square" + ) + + +def test_generation_config_can_come_from_extra_body(): + request = build_request( + "gemini-3.1-flash-image", + extra_body={ + "generationConfig": { + "imageConfig": { + "aspectRatio": "4:3", + "imageSize": "4K", + } + } + }, + ) + + assert ( + resolve_model_name(request.model, request, MODEL_CONFIG) + == "gemini-3.1-flash-image-four-three-4k" + ) + + +@pytest.mark.parametrize( + ("alias", "expected"), + [ + ("veo_3_1_t2v_fast_ultra", "veo_3_1_t2v_fast_portrait_ultra"), + ( + "veo_3_1_t2v_fast_ultra_relaxed", + "veo_3_1_t2v_fast_portrait_ultra_relaxed", + ), + ("veo_3_1_i2v_s_fast_fl", 
"veo_3_1_i2v_s_fast_portrait_fl"), + ("veo_3_1_i2v_s_fast_ultra_fl", "veo_3_1_i2v_s_fast_portrait_ultra_fl"), + ( + "veo_3_1_i2v_s_fast_ultra_relaxed", + "veo_3_1_i2v_s_fast_portrait_ultra_relaxed", + ), + ("veo_3_1_r2v_fast", "veo_3_1_r2v_fast_portrait"), + ("veo_3_1_r2v_fast_ultra", "veo_3_1_r2v_fast_portrait_ultra"), + ( + "veo_3_1_r2v_fast_ultra_relaxed", + "veo_3_1_r2v_fast_portrait_ultra_relaxed", + ), + ], +) +def test_conflicting_video_aliases_resolve_to_portrait(alias, expected): + request = build_request( + alias, + generationConfig={"imageConfig": {"aspectRatio": "9:16"}}, + ) + + assert resolve_model_name(request.model, request, MODEL_CONFIG) == expected