From 72a4f0909e0e77fd2d05d473d17280ecef26efe2 Mon Sep 17 00:00:00 2001 From: TheSmallHanCat <123456789+TheSmallHanCat@users.noreply.github.com> Date: Wed, 25 Feb 2026 12:46:40 +0800 Subject: [PATCH] feat: add semantic probe and token refresh flow --- config/setting.toml | 7 + config/setting_example.toml | 7 + src/api/admin.py | 44 ++++ src/api/routes.py | 159 ++++++++++- src/core/config.py | 56 ++++ src/core/database.py | 202 +++++++++++--- src/core/models.py | 26 ++ src/main.py | 9 + src/services/flow_client.py | 143 +++++++--- src/services/generation_handler.py | 408 +++++++++++++++++++++++++---- src/services/semantic_probe.py | 130 +++++++++ src/services/token_manager.py | 2 +- 12 files changed, 1066 insertions(+), 127 deletions(-) create mode 100644 src/services/semantic_probe.py diff --git a/config/setting.toml b/config/setting.toml index 051cfd3..59a0bba 100644 --- a/config/setting.toml +++ b/config/setting.toml @@ -40,3 +40,10 @@ base_url = "" # 缓存文件访问的基础URL, 留空则使用服务器地址 captcha_method = "browser" # 打码方式: yescaptcha 或 browser yescaptcha_api_key = "" # YesCaptcha API密钥 yescaptcha_base_url = "https://api.yescaptcha.com" + +[semantic_probe] +enabled = false +api_url = "https://api.openai.com/v1/chat/completions" +api_key = "" +model = "gpt-4o-mini" +timeout = 15 diff --git a/config/setting_example.toml b/config/setting_example.toml index 051cfd3..59a0bba 100644 --- a/config/setting_example.toml +++ b/config/setting_example.toml @@ -40,3 +40,10 @@ base_url = "" # 缓存文件访问的基础URL, 留空则使用服务器地址 captcha_method = "browser" # 打码方式: yescaptcha 或 browser yescaptcha_api_key = "" # YesCaptcha API密钥 yescaptcha_base_url = "https://api.yescaptcha.com" + +[semantic_probe] +enabled = false +api_url = "https://api.openai.com/v1/chat/completions" +api_key = "" +model = "gpt-4o-mini" +timeout = 15 diff --git a/src/api/admin.py b/src/api/admin.py index fe9db0a..3d99b1a 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -68,6 +68,14 @@ class GenerationConfigRequest(BaseModel): video_timeout: int +class SemanticProbeConfigRequest(BaseModel): + enabled: Optional[bool] = None + api_url: Optional[str] = None + api_key: Optional[str] = None + model: Optional[str] = None + timeout: Optional[int] = None + + class ChangePasswordRequest(BaseModel): username: Optional[str] = None old_password: str @@ -584,6 +592,42 @@ async def update_generation_config( return {"success": True, "message": "生成配置更新成功"} +@router.get("/api/semantic-probe/config") +async def get_semantic_probe_config(token: str = Depends(verify_admin_token)): + """Get semantic probe configuration""" + probe_config = await db.get_semantic_probe_config() + return { + "success": True, + "config": { + "enabled": probe_config.enabled, + "api_url": probe_config.api_url, + "api_key": probe_config.api_key, + "model": probe_config.model, + "timeout": probe_config.timeout + } + } + + +@router.post("/api/semantic-probe/config") +async def update_semantic_probe_config( + request: SemanticProbeConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update semantic probe configuration""" + await db.update_semantic_probe_config( + enabled=request.enabled, + api_url=request.api_url, + api_key=request.api_key, + model=request.model, + timeout=request.timeout + ) + + # 🔥 Hot reload: sync database config to memory + await db.reload_config_to_memory() + + return {"success": True, "message": "语意探查配置更新成功"} + + # ========== System Info ========== @router.get("/api/system/info") diff --git a/src/api/routes.py b/src/api/routes.py index ec79e11..a3a3cb0 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -1,22 +1,129 @@ """API routes - OpenAI compatible endpoints""" from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse -from typing import List, Optional +from typing import List, Optional, Dict, Any, Tuple import base64 import re import json -import time from urllib.parse import urlparse from curl_cffi.requests import AsyncSession from ..core.auth import verify_api_key_header from ..core.models import ChatCompletionRequest from ..services.generation_handler import GenerationHandler, MODEL_CONFIG +from ..services.semantic_probe import SemanticProbeService from ..core.logger import debug_logger router = APIRouter() - # Dependency injection will be set up in main.py generation_handler: GenerationHandler = None +semantic_probe_service = SemanticProbeService() + + +def _normalize_quality_from_image_quality(image_quality: Optional[str]) -> Optional[str]: + value = (image_quality or "").strip().lower().replace("_", "-") + mapping = { + "standard": "standard", + "normal": "standard", + "default": "standard", + "hd": "ultra", + "high": "ultra", + "ultra": "ultra", + "ultra-relaxed": "ultra_relaxed", + "relaxed": "ultra_relaxed" + } + return mapping.get(value) + + +def _parse_size(size: Optional[str]) -> Tuple[Optional[int], Optional[int]]: + if not size: + return None, None + match = re.match(r"^\s*(\d{2,5})\s*[xX]\s*(\d{2,5})\s*$", size) + if not match: + return None, None + return int(match.group(1)), int(match.group(2)) + + +def _infer_aspect_ratio(width: Optional[int], height: Optional[int]) -> Optional[str]: + if not width or not height or width <= 0 or height <= 0: + return None + + ratio = width / height + if abs(ratio - 1.0) <= 0.08: + return "square" + if abs(ratio - (4 / 3)) <= 0.08: + return "four-three" + if abs(ratio - (3 / 4)) <= 0.08: + return "three-four" + return "landscape" if width > height else "portrait" + + +def _infer_resolution(width: Optional[int], height: Optional[int]) -> Optional[str]: + if not width or not height or width <= 0 or height <= 0: + return None + max_side = max(width, height) + if max_side >= 3000: + return "4k" + if max_side >= 1800: + return "2k" + if max_side >= 1000: + return "1080p" + return None + + +async def _resolve_generation_params( + request: ChatCompletionRequest, + prompt: str, + has_images: bool +) -> Dict[str, Any]: + # 1) 显式参数(最高优先级) + width = request.width + height = request.height + + size_w, size_h = _parse_size(request.size) + if width is None and size_w is not None: + width = size_w + if height is None and size_h is not None: + height = size_h + + derived_aspect = _infer_aspect_ratio(width, height) + derived_resolution = _infer_resolution(width, height) + mapped_quality = _normalize_quality_from_image_quality(request.image_quality) + + current = { + "aspect_ratio": request.aspect_ratio or derived_aspect, + "resolution": request.resolution or derived_resolution, + "quality": request.quality or mapped_quality, + "video_type": request.video_type + } + + # 2) 语意探查仅补全缺失字段 + probed: Dict[str, Optional[str]] = {} + try: + probed = await semantic_probe_service.infer( + prompt=prompt, + has_images=has_images, + current=current + ) + except Exception as e: + debug_logger.log_warning(f"[SEMANTIC_PROBE] 路由推断异常,继续本地逻辑: {str(e)}") + + final_aspect = request.aspect_ratio or derived_aspect or probed.get("aspect_ratio") + final_resolution = request.resolution or derived_resolution or probed.get("resolution") + final_quality = request.quality or mapped_quality or probed.get("quality") + final_video_type = request.video_type or probed.get("video_type") + + image_count = request.n if isinstance(request.n, int) and request.n > 0 else 1 + image_count = min(image_count, 4) + + return { + "aspect_ratio": final_aspect, + "resolution": final_resolution, + "quality": final_quality, + "video_type": final_video_type, + "image_count": image_count, + "image_style": request.style, + "image_seed": request.seed + } def set_generation_handler(handler: GenerationHandler): @@ -142,8 +249,29 @@ async def create_chat_completion( image_bytes = base64.b64decode(image_base64) images.append(image_bytes) - # 自动参考图:仅对图片模型生效 - model_config = MODEL_CONFIG.get(request.model) + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + # 参数映射与语意探查(显式参数优先,探查仅补缺) + effective_params = await _resolve_generation_params( + request=request, + prompt=prompt, + has_images=(len(images) > 0) + ) + + # 自动参考图:仅对图片模型生效(支持通用模型推断) + model_config = None + try: + _, model_config = generation_handler.resolve_model( + model=request.model, + images=images if images else None, + aspect_ratio=effective_params["aspect_ratio"], + resolution=effective_params["resolution"], + quality=effective_params["quality"], + video_type=effective_params["video_type"] + ) + except Exception: + model_config = MODEL_CONFIG.get(request.model) if model_config and model_config["type"] == "image" and len(request.messages) > 1: debug_logger.log_info(f"[CONTEXT] 开始查找历史参考图,消息数量: {len(request.messages)}") @@ -170,9 +298,6 @@ async def create_chat_completion( debug_logger.log_error(f"[CONTEXT] 处理参考图时出错: {str(e)}") # 继续尝试下一个图片 - if not prompt: - raise HTTPException(status_code=400, detail="Prompt cannot be empty") - # Call generation handler if request.stream: # Streaming response @@ -181,7 +306,14 @@ async def generate(): model=request.model, prompt=prompt, images=images if images else None, - stream=True + stream=True, + aspect_ratio=effective_params["aspect_ratio"], + resolution=effective_params["resolution"], + quality=effective_params["quality"], + video_type=effective_params["video_type"], + image_count=effective_params["image_count"], + image_style=effective_params["image_style"], + image_seed=effective_params["image_seed"] ): yield chunk @@ -204,7 +336,14 @@ async def generate(): model=request.model, prompt=prompt, images=images if images else None, - stream=False + stream=False, + aspect_ratio=effective_params["aspect_ratio"], + resolution=effective_params["resolution"], + quality=effective_params["quality"], + video_type=effective_params["video_type"], + image_count=effective_params["image_count"], + image_style=effective_params["image_style"], + image_seed=effective_params["image_seed"] ): result = chunk diff --git a/src/core/config.py b/src/core/config.py index 061f238..60056cb 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -156,6 +156,62 @@ def set_upsample_timeout(self, timeout: int): self._config["generation"] = {} self._config["generation"]["upsample_timeout"] = timeout + # Semantic probe configuration + @property + def semantic_probe_enabled(self) -> bool: + """Get semantic probe enabled status""" + return self._config.get("semantic_probe", {}).get("enabled", False) + + @property + def semantic_probe_api_url(self) -> str: + """Get semantic probe Chat API URL""" + return self._config.get("semantic_probe", {}).get("api_url", "") + + @property + def semantic_probe_api_key(self) -> str: + """Get semantic probe Chat API key""" + return self._config.get("semantic_probe", {}).get("api_key", "") + + @property + def semantic_probe_model(self) -> str: + """Get semantic probe Chat model name""" + return self._config.get("semantic_probe", {}).get("model", "") + + @property + def semantic_probe_timeout(self) -> int: + """Get semantic probe request timeout in seconds""" + return int(self._config.get("semantic_probe", {}).get("timeout", 15)) + + def set_semantic_probe_enabled(self, enabled: bool): + """Set semantic probe enabled status""" + if "semantic_probe" not in self._config: + self._config["semantic_probe"] = {} + self._config["semantic_probe"]["enabled"] = bool(enabled) + + def set_semantic_probe_api_url(self, api_url: str): + """Set semantic probe Chat API URL""" + if "semantic_probe" not in self._config: + self._config["semantic_probe"] = {} + self._config["semantic_probe"]["api_url"] = api_url or "" + + def set_semantic_probe_api_key(self, api_key: str): + """Set semantic probe Chat API key""" + if "semantic_probe" not in self._config: + self._config["semantic_probe"] = {} + self._config["semantic_probe"]["api_key"] = api_key or "" + + def set_semantic_probe_model(self, model: str): + """Set semantic probe model name""" + if "semantic_probe" not in self._config: + self._config["semantic_probe"] = {} + self._config["semantic_probe"]["model"] = model or "" + + def set_semantic_probe_timeout(self, timeout: int): + """Set semantic probe timeout in seconds""" + if "semantic_probe" not in self._config: + self._config["semantic_probe"] = {} + self._config["semantic_probe"]["timeout"] = int(timeout) + # Cache configuration @property def cache_enabled(self) -> bool: diff --git a/src/core/database.py b/src/core/database.py index f819bf9..30ce53f 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, List from pathlib import Path -from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, GenerationConfig, CacheConfig, Project, CaptchaConfig, PluginConfig +from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, GenerationConfig, CacheConfig, Project, CaptchaConfig, PluginConfig, SemanticProbeConfig class Database: @@ -40,6 +40,24 @@ async def _column_exists(self, db, table_name: str, column_name: str) -> bool: except: return False + async def _add_missing_columns_atomic(self, db, table_name: str, columns_to_add: list[tuple[str, str]]): + """Add missing columns atomically for a table. + + If any ALTER TABLE fails, rollback the whole batch for this table. + """ + savepoint = f"sp_add_cols_{table_name}" + await db.execute(f"SAVEPOINT {savepoint}") + try: + for col_name, col_type in columns_to_add: + if not await self._column_exists(db, table_name, col_name): + await db.execute(f"ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type}") + print(f" ✓ Added column '{col_name}' to {table_name} table") + await db.execute(f"RELEASE SAVEPOINT {savepoint}") + except Exception as e: + await db.execute(f"ROLLBACK TO SAVEPOINT {savepoint}") + await db.execute(f"RELEASE SAVEPOINT {savepoint}") + raise RuntimeError(f"Failed to migrate columns for table '{table_name}': {e}") + async def _ensure_config_rows(self, db, config_dict: dict = None): """Ensure all config tables have their default rows @@ -176,6 +194,29 @@ async def _ensure_config_rows(self, db, config_dict: dict = None): VALUES (1, '', 1) """) + # Ensure semantic_probe_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM semantic_probe_config") + count = await cursor.fetchone() + if count[0] == 0: + enabled = False + api_url = "https://api.openai.com/v1/chat/completions" + api_key = "" + model = "gpt-4o-mini" + timeout = 15 + + if config_dict: + semantic_probe = config_dict.get("semantic_probe", {}) + enabled = semantic_probe.get("enabled", False) + api_url = semantic_probe.get("api_url", api_url) + api_key = semantic_probe.get("api_key", "") + model = semantic_probe.get("model", model) + timeout = int(semantic_probe.get("timeout", timeout)) + + await db.execute(""" + INSERT INTO semantic_probe_config (id, enabled, api_url, api_key, model, timeout) + VALUES (1, ?, ?, ?, ?, ?) + """, (enabled, api_url, api_key, model, timeout)) + async def check_and_migrate_db(self, config_dict: dict = None): """Check database integrity and perform migrations if needed @@ -244,6 +285,22 @@ async def check_and_migrate_db(self, config_dict: dict = None): ) """) + # Check and create semantic_probe_config table if missing + if not await self._table_exists(db, "semantic_probe_config"): + print(" ✓ Creating missing table: semantic_probe_config") + await db.execute(""" + CREATE TABLE semantic_probe_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + enabled BOOLEAN DEFAULT 0, + api_url TEXT DEFAULT 'https://api.openai.com/v1/chat/completions', + api_key TEXT DEFAULT '', + model TEXT DEFAULT 'gpt-4o-mini', + timeout INTEGER DEFAULT 15, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + # ========== Step 2: Add missing columns to existing tables ========== # Check and add missing columns to tokens table if await self._table_exists(db, "tokens"): @@ -262,22 +319,15 @@ async def check_and_migrate_db(self, config_dict: dict = None): ("banned_at", "TIMESTAMP"), # 禁用时间 ] - for col_name, col_type in columns_to_add: - if not await self._column_exists(db, "tokens", col_name): - try: - await db.execute(f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}") - print(f" ✓ Added column '{col_name}' to tokens table") - except Exception as e: - print(f" ✗ Failed to add column '{col_name}': {e}") + await self._add_missing_columns_atomic(db, "tokens", columns_to_add) # Check and add missing columns to admin_config table if await self._table_exists(db, "admin_config"): - if not await self._column_exists(db, "admin_config", "error_ban_threshold"): - try: - await db.execute("ALTER TABLE admin_config ADD COLUMN error_ban_threshold INTEGER DEFAULT 3") - print(" ✓ Added column 'error_ban_threshold' to admin_config table") - except Exception as e: - print(f" ✗ Failed to add column 'error_ban_threshold': {e}") + await self._add_missing_columns_atomic( + db, + "admin_config", + [("error_ban_threshold", "INTEGER DEFAULT 3")] + ) # Check and add missing columns to captcha_config table if await self._table_exists(db, "captcha_config"): @@ -293,13 +343,7 @@ async def check_and_migrate_db(self, config_dict: dict = None): ("browser_count", "INTEGER DEFAULT 1"), ] - for col_name, col_type in captcha_columns_to_add: - if not await self._column_exists(db, "captcha_config", col_name): - try: - await db.execute(f"ALTER TABLE captcha_config ADD COLUMN {col_name} {col_type}") - print(f" ✓ Added column '{col_name}' to captcha_config table") - except Exception as e: - print(f" ✗ Failed to add column '{col_name}': {e}") + await self._add_missing_columns_atomic(db, "captcha_config", captcha_columns_to_add) # Check and add missing columns to token_stats table if await self._table_exists(db, "token_stats"): @@ -311,13 +355,7 @@ async def check_and_migrate_db(self, config_dict: dict = None): ("consecutive_error_count", "INTEGER DEFAULT 0"), # 🆕 连续错误计数 ] - for col_name, col_type in stats_columns_to_add: - if not await self._column_exists(db, "token_stats", col_name): - try: - await db.execute(f"ALTER TABLE token_stats ADD COLUMN {col_name} {col_type}") - print(f" ✓ Added column '{col_name}' to token_stats table") - except Exception as e: - print(f" ✗ Failed to add column '{col_name}': {e}") + await self._add_missing_columns_atomic(db, "token_stats", stats_columns_to_add) # Check and add missing columns to plugin_config table if await self._table_exists(db, "plugin_config"): @@ -325,13 +363,7 @@ async def check_and_migrate_db(self, config_dict: dict = None): ("auto_enable_on_update", "BOOLEAN DEFAULT 1"), # 默认开启 ] - for col_name, col_type in plugin_columns_to_add: - if not await self._column_exists(db, "plugin_config", col_name): - try: - await db.execute(f"ALTER TABLE plugin_config ADD COLUMN {col_name} {col_type}") - print(f" ✓ Added column '{col_name}' to plugin_config table") - except Exception as e: - print(f" ✗ Failed to add column '{col_name}': {e}") + await self._add_missing_columns_atomic(db, "plugin_config", plugin_columns_to_add) # ========== Step 3: Ensure all config tables have default rows ========== # Note: This will NOT overwrite existing config rows @@ -531,6 +563,20 @@ async def init_db(self): ) """) + # Semantic probe config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS semantic_probe_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + enabled BOOLEAN DEFAULT 0, + api_url TEXT DEFAULT 'https://api.openai.com/v1/chat/completions', + api_key TEXT DEFAULT '', + model TEXT DEFAULT 'gpt-4o-mini', + timeout INTEGER DEFAULT 15, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + # Create indexes await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_token_st ON tokens(st)") @@ -691,6 +737,17 @@ async def delete_token(self, token_id: int): await db.execute("DELETE FROM tokens WHERE id = ?", (token_id,)) await db.commit() + async def increment_token_use_count(self, token_id: int): + """Atomically increment token use count and update last_used_at""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE tokens + SET use_count = COALESCE(use_count, 0) + 1, + last_used_at = CURRENT_TIMESTAMP + WHERE id = ? + """, (token_id,)) + await db.commit() + # Project operations async def add_project(self, project: Project) -> int: """Add a new project""" @@ -777,6 +834,13 @@ async def update_task(self, task_id: str, **kwargs): await db.commit() # Token stats operations (kept for compatibility, now delegates to specific methods) + async def _ensure_token_stats_row(self, db, token_id: int): + """Ensure token_stats row exists for token_id before any counter update""" + cursor = await db.execute("SELECT 1 FROM token_stats WHERE token_id = ?", (token_id,)) + row = await cursor.fetchone() + if not row: + await db.execute("INSERT INTO token_stats (token_id) VALUES (?)", (token_id,)) + async def increment_token_stats(self, token_id: int, stat_type: str): """Increment token statistics (delegates to specific methods)""" if stat_type == "image": @@ -801,6 +865,8 @@ async def increment_image_count(self, token_id: int): from datetime import date async with aiosqlite.connect(self.db_path) as db: today = str(date.today()) + await self._ensure_token_stats_row(db, token_id) + # Get current stats cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,)) row = await cursor.fetchone() @@ -830,6 +896,8 @@ async def increment_video_count(self, token_id: int): from datetime import date async with aiosqlite.connect(self.db_path) as db: today = str(date.today()) + await self._ensure_token_stats_row(db, token_id) + # Get current stats cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,)) row = await cursor.fetchone() @@ -865,6 +933,8 @@ async def increment_error_count(self, token_id: int): from datetime import date async with aiosqlite.connect(self.db_path) as db: today = str(date.today()) + await self._ensure_token_stats_row(db, token_id) + # Get current stats cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,)) row = await cursor.fetchone() @@ -903,6 +973,7 @@ async def reset_error_count(self, token_id: int): Note: error_count (total historical errors) is NEVER reset """ async with aiosqlite.connect(self.db_path) as db: + await self._ensure_token_stats_row(db, token_id) await db.execute(""" UPDATE token_stats SET consecutive_error_count = 0 WHERE token_id = ? """, (token_id,)) @@ -1109,6 +1180,15 @@ async def reload_config_to_memory(self): config.set_capsolver_api_key(captcha_config.capsolver_api_key) config.set_capsolver_base_url(captcha_config.capsolver_base_url) + # Reload semantic probe config + semantic_probe_config = await self.get_semantic_probe_config() + if semantic_probe_config: + config.set_semantic_probe_enabled(semantic_probe_config.enabled) + config.set_semantic_probe_api_url(semantic_probe_config.api_url) + config.set_semantic_probe_api_key(semantic_probe_config.api_key) + config.set_semantic_probe_model(semantic_probe_config.model) + config.set_semantic_probe_timeout(semantic_probe_config.timeout) + # Cache config operations async def get_cache_config(self) -> CacheConfig: """Get cache configuration""" @@ -1324,3 +1404,55 @@ async def update_plugin_config(self, connection_token: str, auto_enable_on_updat """, (connection_token, auto_enable_on_update)) await db.commit() + + # Semantic probe config operations + async def get_semantic_probe_config(self) -> SemanticProbeConfig: + """Get semantic probe configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM semantic_probe_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return SemanticProbeConfig(**dict(row)) + return SemanticProbeConfig() + + async def update_semantic_probe_config( + self, + enabled: bool = None, + api_url: str = None, + api_key: str = None, + model: str = None, + timeout: int = None + ): + """Update semantic probe configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM semantic_probe_config WHERE id = 1") + row = await cursor.fetchone() + + if row: + current = dict(row) + new_enabled = enabled if enabled is not None else current.get("enabled", False) + new_api_url = api_url if api_url is not None else current.get("api_url", "https://api.openai.com/v1/chat/completions") + new_api_key = api_key if api_key is not None else current.get("api_key", "") + new_model = model if model is not None else current.get("model", "gpt-4o-mini") + new_timeout = int(timeout) if timeout is not None else int(current.get("timeout", 15)) + + await db.execute(""" + UPDATE semantic_probe_config + SET enabled = ?, api_url = ?, api_key = ?, model = ?, timeout = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (new_enabled, new_api_url, new_api_key, new_model, new_timeout)) + else: + new_enabled = enabled if enabled is not None else False + new_api_url = api_url if api_url is not None else "https://api.openai.com/v1/chat/completions" + new_api_key = api_key if api_key is not None else "" + new_model = model if model is not None else "gpt-4o-mini" + new_timeout = int(timeout) if timeout is not None else 15 + + await db.execute(""" + INSERT INTO semantic_probe_config (id, enabled, api_url, api_key, model, timeout) + VALUES (1, ?, ?, ?, ?, ?) + """, (new_enabled, new_api_url, new_api_key, new_model, new_timeout)) + + await db.commit() diff --git a/src/core/models.py b/src/core/models.py index b005026..5e52582 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -174,6 +174,18 @@ class PluginConfig(BaseModel): updated_at: Optional[datetime] = None +class SemanticProbeConfig(BaseModel): + """Semantic probe configuration""" + id: int = 1 + enabled: bool = False + api_url: str = "https://api.openai.com/v1/chat/completions" + api_key: str = "" + model: str = "gpt-4o-mini" + timeout: int = 15 + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + # OpenAI Compatible Request Models class ChatMessage(BaseModel): """Chat message""" @@ -191,3 +203,17 @@ class ChatCompletionRequest(BaseModel): # Flow2API specific parameters image: Optional[str] = None # Base64 encoded image (deprecated, use messages) video: Optional[str] = None # Base64 encoded video (deprecated) + # 通用模型推断参数(可选) + aspect_ratio: Optional[str] = None # landscape/portrait/square/four-three/three-four + resolution: Optional[str] = None # 2k/4k/1080p + quality: Optional[str] = None # standard/ultra/ultra_relaxed + video_type: Optional[str] = None # t2v/i2v/r2v + + # 图片兼容参数(OpenAI 风格) + size: Optional[str] = None # e.g. "1024x1024" / "1792x1024" / "1024x1792" + width: Optional[int] = None + height: Optional[int] = None + n: Optional[int] = 1 + image_quality: Optional[str] = None # standard/hd/high/ultra + style: Optional[str] = None # vivid/natural/... + seed: Optional[int] = None diff --git a/src/main.py b/src/main.py index 991f39d..87603e0 100644 --- a/src/main.py +++ b/src/main.py @@ -79,6 +79,14 @@ async def lifespan(app: FastAPI): config.set_capsolver_api_key(captcha_config.capsolver_api_key) config.set_capsolver_base_url(captcha_config.capsolver_base_url) + # Load semantic probe configuration from database + semantic_probe_config = await db.get_semantic_probe_config() + config.set_semantic_probe_enabled(semantic_probe_config.enabled) + config.set_semantic_probe_api_url(semantic_probe_config.api_url) + config.set_semantic_probe_api_key(semantic_probe_config.api_key) + config.set_semantic_probe_model(semantic_probe_config.model) + config.set_semantic_probe_timeout(semantic_probe_config.timeout) + # Initialize browser captcha service if needed browser_service = None if captcha_config.captcha_method == "personal": @@ -133,6 +141,7 @@ async def auto_unban_task(): print(f"✓ Cache: {'Enabled' if config.cache_enabled else 'Disabled'} (timeout: {config.cache_timeout}s)") print(f"✓ File cache cleanup task started") print(f"✓ 429 auto-unban task started (runs every hour)") + print(f"✓ Semantic probe: {'Enabled' if config.semantic_probe_enabled else 'Disabled'}") print(f"✓ Server running on http://{config.server_host}:{config.server_port}") print("=" * 60) diff --git a/src/services/flow_client.py b/src/services/flow_client.py index cfdbc68..ae458cc 100644 --- a/src/services/flow_client.py +++ b/src/services/flow_client.py @@ -470,7 +470,10 @@ async def generate_image( prompt: str, model_name: str, aspect_ratio: str, - image_inputs: Optional[List[Dict]] = None + image_inputs: Optional[List[Dict]] = None, + image_count: int = 1, + image_style: Optional[str] = None, + image_seed: Optional[int] = None ) -> dict: """生成图片(同步返回) @@ -481,6 +484,9 @@ async def generate_image( model_name: GEM_PIX, GEM_PIX_2 或 IMAGEN_3_5 aspect_ratio: 图片宽高比 image_inputs: 参考图片列表(图生图时使用) + image_count: 生成数量(兼容参数 n) + image_style: 图像风格(兼容参数 style) + image_seed: 随机种子(兼容参数 seed) Returns: { @@ -504,7 +510,12 @@ async def generate_image( # 每次重试都重新获取 reCAPTCHA token recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="IMAGE_GENERATION") if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[IMAGE] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error session_id = self._generate_session_id() # 构建请求 - clientContext 只在外层,requests 内不重复 @@ -518,14 +529,25 @@ async def generate_image( "tool": "PINHOLE" } + normalized_count = image_count if isinstance(image_count, int) and image_count > 0 else 1 + normalized_count = min(normalized_count, 4) + request_data = { - "seed": random.randint(1, 99999), + "seed": image_seed if isinstance(image_seed, int) else random.randint(1, 99999), "imageModelName": model_name, "imageAspectRatio": aspect_ratio, "prompt": prompt, "imageInputs": image_inputs or [] } + if normalized_count > 1: + # OpenAI 兼容参数 n -> 上游批量数量字段 + request_data["sampleCount"] = normalized_count + + if image_style: + # 兼容参数 style(上游若不识别会忽略) + request_data["style"] = image_style + json_data = { "clientContext": client_context, "requests": [request_data] @@ -575,38 +597,62 @@ async def upsample_image( """ url = f"{self.api_base_url}/flow/upsampleImage" - # 获取 reCAPTCHA token - 使用 IMAGE_GENERATION action - recaptcha_token, _ = await self._get_recaptcha_token(project_id, action="IMAGE_GENERATION") - if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") - session_id = self._generate_session_id() + # 403/reCAPTCHA 重试逻辑 - 最多重试3次 + max_retries = 3 + last_error = None - json_data = { - "mediaId": media_id, - "targetResolution": target_resolution, - "clientContext": { - "recaptchaContext": { - "token": recaptcha_token, - "applicationType": "RECAPTCHA_APPLICATION_TYPE_WEB" - }, - "sessionId": session_id, - "projectId": project_id, - "tool": "PINHOLE" + for retry_attempt in range(max_retries): + # 每次重试都重新获取 reCAPTCHA token + recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="IMAGE_GENERATION") + if not recaptcha_token: + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[UPSAMPLE IMAGE] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error + + session_id = self._generate_session_id() + + json_data = { + "mediaId": media_id, + "targetResolution": target_resolution, + "clientContext": { + "recaptchaContext": { + "token": recaptcha_token, + "applicationType": "RECAPTCHA_APPLICATION_TYPE_WEB" + }, + "sessionId": session_id, + "projectId": project_id, + "tool": "PINHOLE" + } } - } - # 4K/2K 放大使用专用超时,因为返回的 base64 数据量很大 - result = await self._make_request( - method="POST", - url=url, - json_data=json_data, - use_at=True, - at_token=at, - timeout=config.upsample_timeout - ) + try: + # 4K/2K 放大使用专用超时,因为返回的 base64 数据量很大 + result = await self._make_request( + method="POST", + url=url, + json_data=json_data, + use_at=True, + at_token=at, + timeout=config.upsample_timeout + ) + # 返回 base64 编码的图片 + return result.get("encodedImage", "") + except Exception as e: + error_str = str(e) + last_error = e + retry_reason = self._get_retry_reason(error_str) + if retry_reason and retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[UPSAMPLE IMAGE] 放大遇到{retry_reason},正在重新获取验证码重试 ({retry_attempt + 2}/{max_retries})...") + await self._notify_browser_captcha_error(browser_id) + await asyncio.sleep(1) + continue + raise e - # 返回 base64 编码的图片 - return result.get("encodedImage", "") + # 所有重试都失败 + raise last_error # ========== 视频生成 (使用AT) - 异步返回 ========== @@ -649,7 +695,12 @@ async def generate_video_text( # 每次重试都重新获取 reCAPTCHA token - 视频使用 VIDEO_GENERATION action recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="VIDEO_GENERATION") if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[VIDEO T2V] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error session_id = self._generate_session_id() scene_id = str(uuid.uuid4()) @@ -735,7 +786,12 @@ async def generate_video_reference_images( # 每次重试都重新获取 reCAPTCHA token - 视频使用 VIDEO_GENERATION action recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="VIDEO_GENERATION") if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[VIDEO R2V] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error session_id = self._generate_session_id() scene_id = str(uuid.uuid4()) @@ -824,7 +880,12 @@ async def generate_video_start_end( # 每次重试都重新获取 reCAPTCHA token - 视频使用 VIDEO_GENERATION action recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="VIDEO_GENERATION") if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[VIDEO I2V] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error session_id = self._generate_session_id() scene_id = str(uuid.uuid4()) @@ -916,7 +977,12 @@ async def generate_video_start_image( # 每次重试都重新获取 reCAPTCHA token - 视频使用 VIDEO_GENERATION action recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="VIDEO_GENERATION") if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[VIDEO START] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error session_id = self._generate_session_id() scene_id = str(uuid.uuid4()) @@ -1005,7 +1071,12 @@ async def upsample_video( for retry_attempt in range(max_retries): recaptcha_token, browser_id = await self._get_recaptcha_token(project_id, action="VIDEO_GENERATION") if not recaptcha_token: - raise Exception("Failed to obtain reCAPTCHA token") + last_error = Exception("Failed to obtain reCAPTCHA token") + if retry_attempt < max_retries - 1: + debug_logger.log_warning(f"[VIDEO UPSAMPLE] 获取 reCAPTCHA token 失败,准备重试 ({retry_attempt + 2}/{max_retries})...") + await asyncio.sleep(1) + continue + raise last_error session_id = self._generate_session_id() scene_id = str(uuid.uuid4()) diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 6865085..03ebf1f 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -3,7 +3,7 @@ import base64 import json import time -from typing import Optional, AsyncGenerator, List, Dict, Any +from typing import Optional, AsyncGenerator, List, Dict, Any, Tuple from ..core.logger import debug_logger from ..core.config import config from ..core.models import Task, RequestLog @@ -591,6 +591,226 @@ def __init__(self, flow_client, token_manager, load_balancer, db, concurrency_ma proxy_manager=proxy_manager ) + @staticmethod + def _normalize_text(value: Optional[str]) -> str: + return (value or "").strip().lower().replace("_", "-") + + @staticmethod + def _normalize_aspect_ratio(aspect_ratio: Optional[str]) -> str: + value = GenerationHandler._normalize_text(aspect_ratio) + mapping = { + "landscape": "landscape", + "horizontal": "landscape", + "16:9": "landscape", + "portrait": "portrait", + "vertical": "portrait", + "9:16": "portrait", + "square": "square", + "1:1": "square", + "four-three": "four-three", + "4:3": "four-three", + "three-four": "three-four", + "3:4": "three-four" + } + return mapping.get(value, "landscape") + + @staticmethod + def _normalize_resolution(resolution: Optional[str]) -> Optional[str]: + value = GenerationHandler._normalize_text(resolution) + if value in ("2k", "4k", "1080p"): + return value + return None + + @staticmethod + def _normalize_quality(quality: Optional[str]) -> str: + value = GenerationHandler._normalize_text(quality) + mapping = { + "standard": "standard", + "normal": "standard", + "default": "standard", + "ultra": "ultra", + "ultra-relaxed": "ultra_relaxed", + "ultra_relaxed": "ultra_relaxed", + "relaxed": "ultra_relaxed" + } + return mapping.get(value, "standard") + + @staticmethod + def _normalize_video_type(video_type: Optional[str]) -> Optional[str]: + value = GenerationHandler._normalize_text(video_type) + if value in ("t2v", "i2v", "r2v"): + return value + return None + + def resolve_model( + self, + model: str, + images: Optional[List[bytes]] = None, + aspect_ratio: Optional[str] = None, + resolution: Optional[str] = None, + quality: Optional[str] = None, + video_type: Optional[str] = None + ) -> Tuple[str, Dict[str, Any]]: + """解析模型:优先精确匹配,其次根据通用模型+参数推断。""" + if model in MODEL_CONFIG: + return model, MODEL_CONFIG[model] + + normalized_model = self._normalize_text(model) + aspect = self._normalize_aspect_ratio(aspect_ratio) + res = self._normalize_resolution(resolution) + q = self._normalize_quality(quality) + + image_count = len(images) if images else 0 + inferred_video_type = self._normalize_video_type(video_type) + if not inferred_video_type: + if image_count == 0: + inferred_video_type = "t2v" + elif image_count <= 2: + inferred_video_type = "i2v" + else: + inferred_video_type = "r2v" + + # ===== 图片通用模型 ===== + if any(k in normalized_model for k in ["gemini-2.5", "gemini-25"]) and "image" in normalized_model: + suffix = "portrait" if aspect == "portrait" else "landscape" + resolved = f"gemini-2.5-flash-image-{suffix}" + if resolved in MODEL_CONFIG: + return resolved, MODEL_CONFIG[resolved] + + if any(k in normalized_model for k in ["gemini-3.0", "gemini-30"]) and "image" in normalized_model: + aspect_map = { + "landscape": "landscape", + "portrait": "portrait", + "square": "square", + "four-three": "four-three", + "three-four": "three-four" + } + aspect_suffix = aspect_map.get(aspect, "landscape") + resolved = f"gemini-3.0-pro-image-{aspect_suffix}" + if res in ("2k", "4k"): + resolved = f"{resolved}-{res}" + if resolved in MODEL_CONFIG: + return resolved, MODEL_CONFIG[resolved] + + if "imagen" in normalized_model and "image" in normalized_model: + suffix = "portrait" if aspect == "portrait" else "landscape" + resolved = f"imagen-4.0-generate-preview-{suffix}" + if resolved in MODEL_CONFIG: + return resolved, MODEL_CONFIG[resolved] + + if normalized_model in ("flow-image", "generic-image", "image"): + aspect_map = { + "landscape": "landscape", + "portrait": "portrait", + "square": "square", + "four-three": "four-three", + "three-four": "three-four" + } + aspect_suffix = aspect_map.get(aspect, "landscape") + resolved = f"gemini-3.0-pro-image-{aspect_suffix}" + if res in ("2k", "4k"): + resolved = f"{resolved}-{res}" + if resolved in MODEL_CONFIG: + return resolved, MODEL_CONFIG[resolved] + + # ===== 视频通用模型 ===== + family = "veo_3_1" + if "veo-2.1" in normalized_model or "veo_2_1" in model: + family = "veo_2_1" + elif "veo-2.0" in normalized_model or "veo_2_0" in model: + family = "veo_2_0" + + is_portrait = (aspect == "portrait") + + def pick(*candidates): + for c in candidates: + if c in MODEL_CONFIG: + return c + return None + + if normalized_model in ("flow-video", "generic-video", "video") or "veo" in normalized_model: + resolved = None + + if inferred_video_type == "t2v": + if family == "veo_3_1": + if res in ("4k", "1080p"): + suffix = "portrait" if is_portrait else "" + if q == "ultra": + resolved = pick( + f"veo_3_1_t2v_fast_{suffix + '_' if suffix else ''}ultra_{res}".replace("__", "_"), + f"veo_3_1_t2v_fast_{res}" + ) + else: + resolved = pick( + f"veo_3_1_t2v_fast_{suffix + '_' if suffix else ''}{res}".replace("__", "_"), + f"veo_3_1_t2v_fast_{res}" + ) + else: + if q == "ultra_relaxed": + resolved = pick( + "veo_3_1_t2v_fast_portrait_ultra_relaxed" if is_portrait else "veo_3_1_t2v_fast_ultra_relaxed" + ) + elif q == "ultra": + resolved = pick( + "veo_3_1_t2v_fast_portrait_ultra" if is_portrait else "veo_3_1_t2v_fast_ultra" + ) + else: + resolved = pick( + "veo_3_1_t2v_fast_portrait" if is_portrait else "veo_3_1_t2v_fast_landscape" + ) + elif family == "veo_2_1": + resolved = pick("veo_2_1_fast_d_15_t2v_portrait" if is_portrait else "veo_2_1_fast_d_15_t2v_landscape") + else: + resolved = pick("veo_2_0_t2v_portrait" if is_portrait else "veo_2_0_t2v_landscape") + + elif inferred_video_type == "i2v": + if family == "veo_3_1": + if res in ("4k", "1080p"): + resolved = pick( + f"veo_3_1_i2v_s_fast_portrait_ultra_fl_{res}" if is_portrait else f"veo_3_1_i2v_s_fast_ultra_fl_{res}" + ) + else: + if q == "ultra_relaxed": + resolved = pick( + "veo_3_1_i2v_s_fast_portrait_ultra_relaxed" if is_portrait else "veo_3_1_i2v_s_fast_ultra_relaxed" + ) + elif q == "ultra": + resolved = pick( + "veo_3_1_i2v_s_fast_portrait_ultra_fl" if is_portrait else "veo_3_1_i2v_s_fast_ultra_fl" + ) + else: + resolved = pick( + "veo_3_1_i2v_s_fast_portrait_fl" if is_portrait else "veo_3_1_i2v_s_fast_fl" + ) + elif family == "veo_2_1": + resolved = pick("veo_2_1_fast_d_15_i2v_portrait" if is_portrait else "veo_2_1_fast_d_15_i2v_landscape") + else: + resolved = pick("veo_2_0_i2v_portrait" if is_portrait else "veo_2_0_i2v_landscape") + + elif inferred_video_type == "r2v": + if res in ("4k", "1080p"): + resolved = pick( + f"veo_3_1_r2v_fast_portrait_ultra_{res}" if is_portrait else f"veo_3_1_r2v_fast_ultra_{res}" + ) + else: + if q == "ultra_relaxed": + resolved = pick( + "veo_3_1_r2v_fast_portrait_ultra_relaxed" if is_portrait else "veo_3_1_r2v_fast_ultra_relaxed" + ) + elif q == "ultra": + resolved = pick( + "veo_3_1_r2v_fast_portrait_ultra" if is_portrait else "veo_3_1_r2v_fast_ultra" + ) + else: + resolved = pick( + "veo_3_1_r2v_fast_portrait" if is_portrait else "veo_3_1_r2v_fast" + ) + + if resolved: + return resolved, MODEL_CONFIG[resolved] + + raise ValueError(f"不支持的模型: {model}") + async def check_token_availability(self, is_image: bool, is_video: bool) -> bool: """检查Token可用性 @@ -612,7 +832,14 @@ async def handle_generation( model: str, prompt: str, images: Optional[List[bytes]] = None, - stream: bool = False + stream: bool = False, + aspect_ratio: Optional[str] = None, + resolution: Optional[str] = None, + quality: Optional[str] = None, + video_type: Optional[str] = None, + image_count: int = 1, + image_style: Optional[str] = None, + image_seed: Optional[int] = None ) -> AsyncGenerator: """统一生成入口 @@ -625,16 +852,32 @@ async def handle_generation( start_time = time.time() token = None - # 1. 验证模型 - if model not in MODEL_CONFIG: - error_msg = f"不支持的模型: {model}" + model_config = None + + # 1. 验证/推断模型 + try: + resolved_model, model_config = self.resolve_model( + model=model, + images=images, + aspect_ratio=aspect_ratio, + resolution=resolution, + quality=quality, + video_type=video_type + ) + except ValueError as e: + error_msg = str(e) debug_logger.log_error(error_msg) yield self._create_error_response(error_msg) return - model_config = MODEL_CONFIG[model] generation_type = model_config["type"] - debug_logger.log_info(f"[GENERATION] 开始生成 - 模型: {model}, 类型: {generation_type}, Prompt: {prompt[:50]}...") + if resolved_model != model: + debug_logger.log_info( + f"[GENERATION] 通用模型推断: {model} -> {resolved_model} " + f"(aspect_ratio={aspect_ratio}, resolution={resolution}, quality={quality}, video_type={video_type}, images={len(images) if images else 0})" + ) + + debug_logger.log_info(f"[GENERATION] 开始生成 - 模型: {resolved_model}, 类型: {generation_type}, Prompt: {prompt[:50]}...") # 非流式模式: 只检查可用性 if not stream: @@ -667,9 +910,9 @@ async def handle_generation( debug_logger.log_info(f"[GENERATION] 正在选择可用Token...") if generation_type == "image": - token = await self.load_balancer.select_token(for_image_generation=True, model=model) + token = await self.load_balancer.select_token(for_image_generation=True, model=resolved_model) else: - token = await self.load_balancer.select_token(for_video_generation=True, model=model) + token = await self.load_balancer.select_token(for_video_generation=True, model=resolved_model) if not token: error_msg = self._get_no_token_error_message(generation_type) @@ -705,16 +948,35 @@ async def handle_generation( debug_logger.log_info(f"[GENERATION] Project ID: {project_id}") # 5. 根据类型处理 + request_context = { + "generated_url": None + } + if generation_type == "image": debug_logger.log_info(f"[GENERATION] 开始图片生成流程...") async for chunk in self._handle_image_generation( - token, project_id, model_config, prompt, images, stream + token, + project_id, + model_config, + prompt, + images, + stream, + image_count=image_count, + image_style=image_style, + image_seed=image_seed, + request_context=request_context ): yield chunk else: # video debug_logger.log_info(f"[GENERATION] 开始视频生成流程...") async for chunk in self._handle_video_generation( - token, project_id, model_config, prompt, images, stream + token, + project_id, + model_config, + prompt, + images, + stream, + request_context=request_context ): yield chunk @@ -733,20 +995,31 @@ async def handle_generation( # 构建响应数据,包含生成的URL response_data = { "status": "success", - "model": model, + "model": resolved_model, + "requested_model": model, "prompt": prompt[:100] } # 添加生成的URL(如果有) - if hasattr(self, '_last_generated_url') and self._last_generated_url: - response_data["url"] = self._last_generated_url - # 清除临时存储 - self._last_generated_url = None + if request_context.get("generated_url"): + response_data["url"] = request_context["generated_url"] await self._log_request( token.id, f"generate_{generation_type}", - {"model": model, "prompt": prompt[:100], "has_images": images is not None and len(images) > 0}, + { + "model": resolved_model, + "requested_model": model, + "prompt": prompt[:100], + "has_images": images is not None and len(images) > 0, + "aspect_ratio": aspect_ratio, + "resolution": resolution, + "quality": quality, + "video_type": video_type, + "image_count": image_count, + "image_style": image_style, + "image_seed": image_seed + }, response_data, 200, duration @@ -767,7 +1040,18 @@ async def handle_generation( await self._log_request( token.id if token else None, f"generate_{generation_type if model_config else 'unknown'}", - {"model": model, "prompt": prompt[:100], "has_images": images is not None and len(images) > 0}, + { + "model": model, + "prompt": prompt[:100], + "has_images": images is not None and len(images) > 0, + "aspect_ratio": aspect_ratio, + "resolution": resolution, + "quality": quality, + "video_type": video_type, + "image_count": image_count, + "image_style": image_style, + "image_seed": image_seed + }, {"error": error_msg}, 500, duration @@ -787,17 +1071,23 @@ async def _handle_image_generation( model_config: dict, prompt: str, images: Optional[List[bytes]], - stream: bool + stream: bool, + image_count: int = 1, + image_style: Optional[str] = None, + image_seed: Optional[int] = None, + request_context: Optional[dict] = None ) -> AsyncGenerator: """处理图片生成 (同步返回)""" - # 获取并发槽位 - if self.concurrency_manager: - if not await self.concurrency_manager.acquire_image(token.id): - yield self._create_error_response("图片并发限制已达上限") - return - + slot_acquired = False try: + # 获取并发槽位 + if self.concurrency_manager: + if not await self.concurrency_manager.acquire_image(token.id): + yield self._create_error_response("图片并发限制已达上限") + return + slot_acquired = True + # 上传图片 (如果有) image_inputs = [] if images and len(images) > 0: @@ -828,7 +1118,10 @@ async def _handle_image_generation( prompt=prompt, model_name=model_config["model_name"], aspect_ratio=model_config["aspect_ratio"], - image_inputs=image_inputs + image_inputs=image_inputs, + image_count=image_count, + image_style=image_style, + image_seed=image_seed ) # 提取URL和mediaId @@ -867,7 +1160,8 @@ async def _handle_image_generation( # 缓存放大后的图片 (如果启用) # 日志统一记录原图URL (放大后的base64数据太大,不适合存储) - self._last_generated_url = image_url + if request_context is not None: + request_context["generated_url"] = image_url if config.cache_enabled: try: @@ -949,8 +1243,9 @@ async def _handle_image_generation( yield self._create_stream_chunk("缓存已关闭,正在返回源链接...\n") # 返回结果 - # 存储URL用于日志记录 - self._last_generated_url = local_url + # 存储URL用于日志记录(请求级上下文,避免并发覆盖) + if request_context is not None: + request_context["generated_url"] = local_url if stream: yield self._create_stream_chunk( @@ -965,7 +1260,7 @@ async def _handle_image_generation( finally: # 释放并发槽位 - if self.concurrency_manager: + if self.concurrency_manager and slot_acquired: await self.concurrency_manager.release_image(token.id) async def _handle_video_generation( @@ -975,17 +1270,21 @@ async def _handle_video_generation( model_config: dict, prompt: str, images: Optional[List[bytes]], - stream: bool + stream: bool, + request_context: Optional[dict] = None ) -> AsyncGenerator: """处理视频生成 (异步轮询)""" - # 获取并发槽位 - if self.concurrency_manager: - if not await self.concurrency_manager.acquire_video(token.id): - yield self._create_error_response("视频并发限制已达上限") - return - + slot_state = {"video_slot_acquired": False} try: + # 获取并发槽位 + if self.concurrency_manager: + acquired = await self.concurrency_manager.acquire_video(token.id) + slot_state["video_slot_acquired"] = acquired + if not acquired: + yield self._create_error_response("视频并发限制已达上限") + return + # 获取模型类型和配置 video_type = model_config.get("video_type") supports_images = model_config.get("supports_images", False) @@ -1005,8 +1304,8 @@ async def _handle_video_generation( # veo_3_1_t2v_fast -> veo_3_1_t2v_fast_ultra # veo_3_1_t2v_fast_portrait -> veo_3_1_t2v_fast_portrait_ultra # veo_3_0_r2v_fast -> veo_3_0_r2v_fast_ultra - if "_fl" in model_key: - model_key = model_key.replace("_fl", "_ultra_fl") + if model_key.endswith("_fl"): + model_key = model_key[:-3] + "_ultra_fl" else: # 直接在末尾添加 _ultra model_key = model_key + "_ultra" @@ -1189,12 +1488,19 @@ async def _handle_video_generation( # 检查是否需要放大 upsample_config = model_config.get("upsample") - async for chunk in self._poll_video_result(token, project_id, operations, stream, upsample_config): + async for chunk in self._poll_video_result( + token, + project_id, + operations, + stream, + upsample_config, + slot_state=slot_state + ): yield chunk finally: # 释放并发槽位 - if self.concurrency_manager: + if self.concurrency_manager and slot_state.get("video_slot_acquired"): await self.concurrency_manager.release_video(token.id) async def _poll_video_result( @@ -1203,7 +1509,8 @@ async def _poll_video_result( project_id: str, operations: List[Dict], stream: bool, - upsample_config: Optional[Dict] = None + upsample_config: Optional[Dict] = None, + slot_state: Optional[Dict] = None ) -> AsyncGenerator: """轮询视频生成结果 @@ -1272,9 +1579,19 @@ async def _poll_video_result( if stream: yield self._create_stream_chunk("放大任务已提交,继续轮询...\n") + # 放大阶段通常耗时更长,此处提前释放视频并发槽位 + if self.concurrency_manager and slot_state and slot_state.get("video_slot_acquired"): + await self.concurrency_manager.release_video(token.id) + slot_state["video_slot_acquired"] = False + # 递归轮询放大结果(不再放大) async for chunk in self._poll_video_result( - token, project_id, upsample_operations, stream, None + token, + project_id, + upsample_operations, + stream, + None, + slot_state=slot_state ): yield chunk return @@ -1316,8 +1633,9 @@ async def _poll_video_result( completed_at=time.time() ) - # 存储URL用于日志记录 - self._last_generated_url = local_url + # 存储URL用于日志记录(请求级上下文,避免并发覆盖) + if request_context is not None: + request_context["generated_url"] = local_url # 返回结果 if stream: diff --git a/src/services/semantic_probe.py b/src/services/semantic_probe.py new file mode 100644 index 0000000..2a59410 --- /dev/null +++ b/src/services/semantic_probe.py @@ -0,0 +1,130 @@ +"""Semantic probe service for auto-inferring generation parameters via configurable Chat API.""" +import json +import re +from typing import Optional, Dict, Any +from curl_cffi.requests import AsyncSession + +from ..core.config import config +from ..core.logger import debug_logger + + +class SemanticProbeService: + """Use external Chat API to infer aspect_ratio/resolution/video_type/quality from prompt.""" + + @staticmethod + def _extract_json(text: str) -> Optional[Dict[str, Any]]: + if not text: + return None + + # Try direct JSON first + try: + data = json.loads(text) + if isinstance(data, dict): + return data + except Exception: + pass + + # Try fenced JSON block + match = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", text, re.IGNORECASE) + if match: + try: + data = json.loads(match.group(1)) + if isinstance(data, dict): + return data + except Exception: + pass + + # Try first object in text + match = re.search(r"(\{[\s\S]*\})", text) + if match: + try: + data = json.loads(match.group(1)) + if isinstance(data, dict): + return data + except Exception: + pass + + return None + + async def infer( + self, + prompt: str, + has_images: bool, + current: Dict[str, Optional[str]] + ) -> Dict[str, Optional[str]]: + """Infer missing params. Returns empty dict if probe disabled/unavailable/failed.""" + if not config.semantic_probe_enabled: + return {} + + api_url = config.semantic_probe_api_url + api_key = config.semantic_probe_api_key + model = config.semantic_probe_model + + if not api_url or not model: + return {} + + system_prompt = ( + "你是参数推断器。请根据用户生成意图输出JSON,字段仅包含:" + "aspect_ratio,resolution,quality,video_type。\n" + "取值约束:\n" + "- aspect_ratio: landscape|portrait|square|four-three|three-four\n" + "- resolution: 2k|4k|1080p|null\n" + "- quality: standard|ultra|ultra_relaxed|null\n" + "- video_type: t2v|i2v|r2v|null\n" + "只输出JSON,不要解释。" + ) + + user_prompt = { + "prompt": prompt, + "has_images": has_images, + "current": current + } + + payload = { + "model": model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": json.dumps(user_prompt, ensure_ascii=False)} + ], + "temperature": 0, + "stream": False + } + + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + try: + async with AsyncSession() as session: + resp = await session.post( + api_url, + headers=headers, + json=payload, + timeout=config.semantic_probe_timeout, + impersonate="chrome110" + ) + + if resp.status_code >= 400: + debug_logger.log_warning(f"[SEMANTIC_PROBE] HTTP {resp.status_code}: {resp.text[:300]}") + return {} + + result = resp.json() + content = ( + result.get("choices", [{}])[0] + .get("message", {}) + .get("content", "") + ) + + data = self._extract_json(content) + if not data: + return {} + + return { + "aspect_ratio": data.get("aspect_ratio"), + "resolution": data.get("resolution"), + "quality": data.get("quality"), + "video_type": data.get("video_type") + } + except Exception as e: + debug_logger.log_warning(f"[SEMANTIC_PROBE] 推断失败,已回退本地逻辑: {str(e)}") + return {} diff --git a/src/services/token_manager.py b/src/services/token_manager.py index e5e40a0..70492ef 100644 --- a/src/services/token_manager.py +++ b/src/services/token_manager.py @@ -456,7 +456,7 @@ async def ensure_project_exists(self, token_id: int) -> str: async def record_usage(self, token_id: int, is_video: bool = False): """Record token usage""" - await self.db.update_token(token_id, use_count=1, last_used_at=datetime.now()) + await self.db.increment_token_use_count(token_id) if is_video: await self.db.increment_token_stats(token_id, "video")