From fe748daca7ef5808e6c5f63fddb0647713425cc0 Mon Sep 17 00:00:00 2001 From: 827dls <1670704430@qq.com> Date: Sun, 17 May 2026 16:46:16 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E8=A7=86=E8=A7=89=E8=AF=AD=E8=A8=80?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=88=86=E7=B1=BB=20=E4=BB=A5=E5=8F=8A?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=9F=B3=E9=A2=91=E7=90=86=E8=A7=A3=E5=B7=A5?= =?UTF-8?q?=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 62 +++++-- backend/consts/const.py | 2 + backend/consts/model.py | 10 ++ .../managed_system_prompt_template_en.yaml | 2 +- .../manager_system_prompt_template_en.yaml | 2 +- backend/services/agent_service.py | 2 +- backend/services/config_sync_service.py | 20 ++- backend/services/image_service.py | 30 +++- backend/services/model_health_service.py | 2 +- backend/services/model_management_service.py | 18 +- .../services/providers/silicon_provider.py | 76 +++++++- .../services/tool_configuration_service.py | 20 ++- docker/.env.bak | 168 ----------------- .../components/agentConfig/ToolManagement.tsx | 43 ++++- .../chat/components/chatAttachment.tsx | 10 +- .../[locale]/chat/components/chatInput.tsx | 22 ++- .../[locale]/chat/internal/chatInterface.tsx | 16 +- .../components/model/ModelAddDialog.tsx | 70 +++++++- .../components/model/ModelDeleteDialog.tsx | 31 +++- .../components/model/ModelEditDialog.tsx | 12 +- .../models/components/modelConfig.tsx | 46 ++++- frontend/const/chatConfig.ts | 14 +- frontend/const/modelConfig.ts | 2 + frontend/hooks/model/useDashscopeModelList.ts | 4 +- frontend/hooks/model/useTokenponyModelList.ts | 4 +- frontend/hooks/useConfig.ts | 44 ++++- frontend/lib/chat/chatAttachmentUtils.ts | 13 ++ frontend/public/locales/en/common.json | 8 +- frontend/public/locales/zh/common.json | 8 +- ...{tailwind.config.ts => tailwind.config.js} | 0 frontend/types/modelConfig.ts | 4 + sdk/nexent/core/agents/nexent_agent.py | 4 +- sdk/nexent/core/models/openai_vlm.py | 61 +++++++ sdk/nexent/core/prompts/analyze_audio_en.yaml | 13 ++ sdk/nexent/core/prompts/analyze_audio_zh.yaml | 13 ++ sdk/nexent/core/prompts/analyze_video_en.yaml | 13 ++ sdk/nexent/core/prompts/analyze_video_zh.yaml | 13 ++ sdk/nexent/core/tools/__init__.py | 4 + sdk/nexent/core/tools/analyze_audio_tool.py | 169 ++++++++++++++++++ sdk/nexent/core/tools/analyze_image_tool.py | 20 +-- sdk/nexent/core/tools/analyze_video_tool.py | 156 ++++++++++++++++ .../core/utils/prompt_template_utils.py | 12 +- sdk/nexent/multi_modal/utils.py | 20 ++- test/backend/agents/test_create_agent_info.py | 134 +++++++++++++- .../providers/test_silicon_provider.py | 72 +++++++- test/backend/services/test_image_service.py | 82 +++++++-- .../services/test_model_management_service.py | 50 +++--- .../test_tool_configuration_service.py | 108 ++++++++++- test/common/test_mocks.py | 2 + test/sdk/core/agents/test_nexent_agent.py | 46 +++++ test/sdk/core/models/test_openai_vlm.py | 60 ++++++- .../tools/test_analyze_audio_video_tool.py | 119 ++++++++++++ .../sdk/core/tools/test_analyze_image_tool.py | 6 +- .../core/utils/test_prompt_template_utils.py | 24 ++- test/sdk/multi_modal/test_load_save_object.py | 17 ++ 55 files changed, 1674 insertions(+), 309 deletions(-) delete mode 100644 docker/.env.bak rename frontend/{tailwind.config.ts => tailwind.config.js} (100%) create mode 100644 sdk/nexent/core/prompts/analyze_audio_en.yaml create mode 100644 sdk/nexent/core/prompts/analyze_audio_zh.yaml create mode 100644 sdk/nexent/core/prompts/analyze_video_en.yaml create mode 100644 sdk/nexent/core/prompts/analyze_video_zh.yaml create mode 100644 sdk/nexent/core/tools/analyze_audio_tool.py create mode 100644 sdk/nexent/core/tools/analyze_video_tool.py create mode 100644 test/sdk/core/tools/test_analyze_audio_video_tool.py diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 5a11b550b..9033f9ba0 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -21,7 +21,7 @@ from database.a2a_agent_db import PROTOCOL_JSONRPC from services.memory_config_service import build_memory_context -from services.image_service import get_vlm_model +from services.image_service import get_video_understanding_model, get_vlm_model from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list from database.agent_version_db import query_current_version_no from database.tool_db import search_tools_for_sub_agent @@ -31,13 +31,36 @@ from utils.model_name_utils import add_repo_to_name from utils.prompt_template_utils import get_agent_prompt_template from utils.config_utils import tenant_config_manager, get_model_name_from_config -from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE +from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET from consts.exceptions import ValidationError logger = logging.getLogger("create_agent_info") logger.setLevel(logging.DEBUG) +def _build_internal_s3_url(file: dict) -> str: + """Build a valid S3 URL for internal tools from uploaded file metadata.""" + if not isinstance(file, dict): + return "" + + object_name = str(file.get("object_name") or "").strip().lstrip("/") + if object_name: + bucket = MINIO_DEFAULT_BUCKET or "nexent" + return f"s3://{bucket}/{object_name}" + + url = str(file.get("url") or "").strip() + if not url or url.startswith("blob:") or url.startswith("s3:/blob:"): + return "" + + if url.startswith("s3://"): + return url + + if url.startswith("s3:/"): + return "s3://" + url.replace("s3:/", "", 1).lstrip("/") + + return "s3:/" + url + + def _get_skills_for_template( agent_id: int, tenant_id: str, @@ -526,10 +549,17 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int } elif tool_config.class_name == "AnalyzeImageTool": tool_config.metadata = { + # get_vlm_model reads the first multimodal slot, now shown as image understanding. "vlm_model": get_vlm_model(tenant_id=tenant_id), "storage_client": minio_client, "validate_url_access": lambda urls: validate_urls_access(urls, user_id) } + elif tool_config.class_name in ["AnalyzeAudioTool", "AnalyzeVideoTool"]: + tool_config.metadata = { + "vlm_model": get_video_understanding_model(tenant_id=tenant_id), + "storage_client": minio_client, + "validate_url_access": lambda urls: validate_urls_access(urls, user_id) + } tool_config_list.append(tool_config) @@ -630,10 +660,12 @@ async def join_minio_file_description_to_query( # Collect files from current message first (higher priority) if minio_files and isinstance(minio_files, list): for file in minio_files: - if isinstance(file, dict) and file.get("url") and file.get("name"): - url = file["url"] - if url not in seen_urls: - seen_urls.add(url) + if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")): + s3_url = _build_internal_s3_url(file) + if not s3_url: + continue + if s3_url not in seen_urls: + seen_urls.add(s3_url) all_files.append(file) # Collect files from historical messages (lower priority, already-deduped) @@ -641,10 +673,12 @@ async def join_minio_file_description_to_query( for msg in history: if isinstance(msg, dict) and msg.get("minio_files"): for file in msg["minio_files"]: - if isinstance(file, dict) and file.get("url") and file.get("name"): - url = file["url"] - if url not in seen_urls: - seen_urls.add(url) + if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")): + s3_url = _build_internal_s3_url(file) + if not s3_url: + continue + if s3_url not in seen_urls: + seen_urls.add(s3_url) all_files.append(file) # Enforce file count limit (keep most recent files by truncating from the end) @@ -660,7 +694,7 @@ async def join_minio_file_description_to_query( fixed_overhead = len(prefix) + len(suffix) for i, file in enumerate(all_files): - s3_url = f"s3:/{file['url']}" + s3_url = _build_internal_s3_url(file) presigned_url = file.get("presigned_url", "") # Build description with both URLs @@ -712,8 +746,10 @@ def _format_minio_files_for_content(minio_files: Optional[List[dict]], max_files if i >= max_files: file_lines.append(f" - ... (and {len(minio_files) - max_files} more files)") break - if isinstance(file, dict) and file.get("url") and file.get("name"): - s3_url = f"s3:/{file['url']}" + if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")): + s3_url = _build_internal_s3_url(file) + if not s3_url: + continue presigned_url = file.get("presigned_url", "") if presigned_url: file_lines.append( diff --git a/backend/consts/const.py b/backend/consts/const.py index 77e86a185..6c8a43266 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -305,6 +305,8 @@ class VectorDatabaseType(str, Enum): "multiEmbedding": "MULTI_EMBEDDING_ID", "rerank": "RERANK_ID", "vlm": "VLM_ID", + "vlm2": "VLM2_ID", + "vlm3": "VLM3_ID", "stt": "STT_ID", "tts": "TTS_ID" } diff --git a/backend/consts/model.py b/backend/consts/model.py index bcaffcae7..31f611f8c 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -160,12 +160,22 @@ class STTModelConfig(BaseModel): accessToken: Optional[str] = None +def _empty_model_config() -> SingleModelConfig: + return SingleModelConfig( + modelName="", + displayName="", + apiConfig=ModelApiConfig(apiKey="", modelUrl="") + ) + + class ModelConfig(BaseModel): llm: SingleModelConfig embedding: SingleModelConfig multiEmbedding: SingleModelConfig rerank: SingleModelConfig vlm: SingleModelConfig + vlm2: SingleModelConfig = Field(default_factory=_empty_model_config) + vlm3: SingleModelConfig = Field(default_factory=_empty_model_config) stt: STTModelConfig diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml index 67da8305c..94b35f66d 100644 --- a/backend/prompts/managed_system_prompt_template_en.yaml +++ b/backend/prompts/managed_system_prompt_template_en.yaml @@ -116,7 +116,7 @@ system_prompt: |- → Use **presigned_url** (already includes proxy prefix, format: `http://.../api/nb/v1/file/fetch?presigned_url=...`) Directly use the **presigned_url** field provided in the user's uploaded file info. No need to construct or append anything. 2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image): - → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`) + → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`) Reason: Internal tools run inside Nexent and can directly access MinIO storage {%- else %} diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml index a4ffae074..feccecc9c 100644 --- a/backend/prompts/manager_system_prompt_template_en.yaml +++ b/backend/prompts/manager_system_prompt_template_en.yaml @@ -118,7 +118,7 @@ system_prompt: |- → Use **Download URL** (format: `https://minio.example.com/...?token=xxx`) Reason: MCP tools run on external services and cannot access internal S3 storage 2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image): - → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`) + → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`) Reason: Internal tools run inside Nexent and can directly access MinIO storage {%- else %} - No tools are currently available diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 02fa7d8c6..080c438a8 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -1152,7 +1152,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) # Check if any tool is KnowledgeBaseSearchTool and set its metadata to empty dict for tool in tool_list: - if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "DataMateSearchTool"]: + if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool", "DataMateSearchTool"]: tool.metadata = {} # Get model_id and model display name from agent_info diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 0ed29bfc5..2a585fd50 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -20,7 +20,7 @@ MODEL_ENGINE_ENABLED, TENANT_NAME ) -from database.model_management_db import get_model_id_by_display_name +from database.model_management_db import get_model_id_by_display_name, get_model_records from utils.config_utils import ( get_env_key, get_model_name_from_config, @@ -31,6 +31,20 @@ logger = logging.getLogger("config_sync_service") +def get_model_id_for_config(model_type: str, display_name: str, tenant_id: str) -> Optional[int]: + if not display_name: + return None + + records = get_model_records( + {"display_name": display_name, "model_type": model_type}, + tenant_id + ) + if records: + return records[0].get("model_id") + + return get_model_id_by_display_name(display_name, tenant_id) + + def handle_model_config(tenant_id: str, user_id: str, config_key: str, model_id: Optional[int], tenant_config_dict: dict) -> None: """ Handle model configuration updates, deletions, and settings operations @@ -98,8 +112,8 @@ async def save_config_impl(config, tenant_id, user_id): model_display_name = model_config.get("displayName") config_key = get_env_key(model_type) + "_ID" - model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_id = get_model_id_for_config( + model_type, model_display_name, tenant_id) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/image_service.py b/backend/services/image_service.py index 8decbd541..8a924e9cc 100644 --- a/backend/services/image_service.py +++ b/backend/services/image_service.py @@ -31,7 +31,11 @@ async def proxy_image_impl(decoded_url: str): def get_vlm_model(tenant_id: str): - # Get the tenant config + """Return the configured image understanding model for AnalyzeImageTool. + + The first multimodal model slot is still stored under MODEL_CONFIG_MAPPING["vlm"] + for compatibility, but it is the user-facing image understanding configuration. + """ vlm_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["vlm"], tenant_id=tenant_id) if not vlm_model_config: @@ -48,3 +52,27 @@ def get_vlm_model(tenant_id: str): max_tokens=512, ssl_verify=vlm_model_config.get("ssl_verify", True), ) + + +def get_image_understanding_model(tenant_id: str): + return get_vlm_model(tenant_id=tenant_id) + + +def get_video_understanding_model(tenant_id: str): + """Return the configured video understanding model for multimodal tools.""" + vlm_model_config = tenant_config_manager.get_model_config( + key=MODEL_CONFIG_MAPPING["vlm3"], tenant_id=tenant_id) + if not vlm_model_config: + return None + return OpenAIVLModel( + observer=MessageObserver(), + model_id=get_model_name_from_config( + vlm_model_config) if vlm_model_config else "", + api_base=vlm_model_config.get("base_url", ""), + api_key=vlm_model_config.get("api_key", ""), + temperature=0.7, + top_p=0.7, + frequency_penalty=0.5, + max_tokens=512, + ssl_verify=vlm_model_config.get("ssl_verify", True), + ) diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index a20b2a6ca..5044f50f0 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -125,7 +125,7 @@ async def _perform_connectivity_check( ssl_verify=ssl_verify, ) connectivity = await rerank_model.connectivity_check() - elif model_type == "vlm": + elif model_type in ("vlm", "vlm2", "vlm3"): observer = MessageObserver() set_monitoring_operation("connectivity_check", display_name=display_name) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index d012803be..268c00413 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -153,6 +153,13 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay tenant_id, provider, model_type) model_list_ids = {model.get("id") for model in model_list} if model_list else set() + existing_model_map = { + add_repo_to_name( + model_repo=model["model_repo"], + model_name=model["model_name"], + ): model + for model in existing_model_list + } # Delete existing models not present for model in existing_model_list: @@ -162,21 +169,20 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay # Create or update new models for model in model_list: + model["model_type"] = model_type _, model_name = split_repo_name( model["id"]) if model.get("id") else ("", "") model_repo, model_name_only = split_repo_name( model.get("id", "")) if model.get("id") else ("", "") model_display_name = add_repo_to_name(model_repo, model_name_only) if model_name: - existing_model_by_display = get_model_by_display_name( - model_display_name, tenant_id) - if existing_model_by_display: + existing_model = existing_model_map.get(model_display_name) + if existing_model: # Check if max_tokens has changed - existing_max_tokens = existing_model_by_display.get( - "max_tokens") + existing_max_tokens = existing_model.get("max_tokens") new_max_tokens = model.get("max_tokens") if new_max_tokens is not None and existing_max_tokens != new_max_tokens: - update_model_record(existing_model_by_display["model_id"], { + update_model_record(existing_model["model_id"], { "max_tokens": new_max_tokens}, user_id) continue diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py index ea41cc95d..130f2346e 100644 --- a/backend/services/providers/silicon_provider.py +++ b/backend/services/providers/silicon_provider.py @@ -1,4 +1,5 @@ import httpx +import re from typing import Dict, List from consts.const import DEFAULT_LLM_MAX_TOKENS @@ -6,6 +7,62 @@ from services.providers.base import AbstractModelProvider, _classify_provider_error +SILICON_VLM_MODEL_KEYWORDS = ( + "-vl", + "_vl", + "/vl", + ".vl", + "vl-", + "vision", + "visual", + "internvl", + "deepseek-vl", + "deepseekvl", + "glm-4v", + "minicpm-v", + "llava", + "kimi-vl", + "kimi-k2.5", + "kimi-k2.6", + "qvq", + "omni", + "qwen3.5", + "qwen3.6", +) + +SILICON_VLM_METADATA_KEYWORDS = ("image", "video", "vision", "visual") + + +def _contains_silicon_vlm_metadata(value) -> bool: + if isinstance(value, str): + lower_value = value.lower() + return any(keyword in lower_value for keyword in SILICON_VLM_METADATA_KEYWORDS) + if isinstance(value, list): + return any(_contains_silicon_vlm_metadata(item) for item in value) + if isinstance(value, dict): + return any(_contains_silicon_vlm_metadata(item) for item in value.values()) + return False + + +def _is_silicon_vlm_model(model: Dict) -> bool: + if _contains_silicon_vlm_metadata(model): + return True + + model_id = str(model.get("id", "")).lower() + model_name = str(model.get("name", "")).lower() + searchable_text = f"{model_id} {model_name}" + if any(keyword in searchable_text for keyword in SILICON_VLM_MODEL_KEYWORDS): + return True + + return bool(re.search(r"glm-\d+(?:\.\d+)?v", searchable_text)) + + +def _is_silicon_omni_model(model: Dict) -> bool: + model_id = str(model.get("id", "")).lower() + model_name = str(model.get("name", "")).lower() + return "omni" in f"{model_id} {model_name}" + + class SiliconModelProvider(AbstractModelProvider): """Concrete implementation for SiliconFlow provider.""" @@ -25,12 +82,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: headers = {"Authorization": f"Bearer {model_api_key}"} + provider_model_type = "vlm" if model_type in ("vlm2", "vlm3") else model_type + # Choose endpoint by model type - if model_type in ("llm", "vlm"): + if provider_model_type in ("llm", "vlm"): silicon_url = f"{SILICON_GET_URL}?sub_type=chat" - elif model_type in ("embedding", "multi_embedding"): + elif provider_model_type in ("embedding", "multi_embedding"): silicon_url = f"{SILICON_GET_URL}?sub_type=embedding" - elif model_type == "rerank": + elif provider_model_type == "rerank": silicon_url = f"{SILICON_GET_URL}?sub_type=reranker" else: silicon_url = SILICON_GET_URL @@ -40,17 +99,22 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: response.raise_for_status() model_list: List[Dict] = response.json()["data"] + if model_type == "vlm3": + model_list = [item for item in model_list if _is_silicon_omni_model(item)] + elif provider_model_type == "vlm": + model_list = [item for item in model_list if _is_silicon_vlm_model(item)] + # Annotate models with canonical fields expected downstream - if model_type in ("llm", "vlm"): + if provider_model_type in ("llm", "vlm"): for item in model_list: item["model_tag"] = "chat" item["model_type"] = model_type item["max_tokens"] = DEFAULT_LLM_MAX_TOKENS - elif model_type in ("embedding", "multi_embedding"): + elif provider_model_type in ("embedding", "multi_embedding"): for item in model_list: item["model_tag"] = "embedding" item["model_type"] = model_type - elif model_type == "rerank": + elif provider_model_type == "rerank": for item in model_list: item["model_tag"] = "rerank" item["model_type"] = model_type diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 5e5229ff6..d3f3d513b 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -38,7 +38,7 @@ from services.file_management_service import get_llm_model, validate_urls_access from services.vectordatabase_service import get_embedding_model_by_index_name, get_rerank_model from database.client import minio_client -from services.image_service import get_vlm_model +from services.image_service import get_video_understanding_model, get_vlm_model from nexent.monitor import set_monitoring_context, set_monitoring_operation from services.vectordatabase_service import get_vector_db_core from utils.langchain_utils import discover_langchain_modules @@ -765,6 +765,7 @@ def _validate_local_tool( if not tenant_id or not user_id: raise ToolExecutionException( f"Tenant ID and User ID are required for {tool_name} validation") + # get_vlm_model reads the first multimodal slot, now shown as image understanding. image_to_text_model = get_vlm_model(tenant_id=tenant_id) vlm_display_name = getattr( image_to_text_model, 'display_name', None) @@ -778,6 +779,23 @@ def _validate_local_tool( 'validate_url_access': lambda urls: validate_urls_access(urls, user_id) } tool_instance = tool_class(**params) + elif tool_name in ["analyze_audio", "analyze_video"]: + if not tenant_id or not user_id: + raise ToolExecutionException( + f"Tenant ID and User ID are required for {tool_name} validation") + video_understanding_model = get_video_understanding_model(tenant_id=tenant_id) + model_display_name = getattr( + video_understanding_model, 'display_name', None) + set_monitoring_context(tenant_id=tenant_id) + set_monitoring_operation( + "tool_validation", display_name=model_display_name) + params = { + **instantiation_params, + 'vlm_model': video_understanding_model, + 'storage_client': minio_client, + 'validate_url_access': lambda urls: validate_urls_access(urls, user_id) + } + tool_instance = tool_class(**params) elif tool_name == "analyze_text_file": if not tenant_id or not user_id: raise ToolExecutionException( diff --git a/docker/.env.bak b/docker/.env.bak deleted file mode 100644 index 24b53751b..000000000 --- a/docker/.env.bak +++ /dev/null @@ -1,168 +0,0 @@ -# ===== Necessary Configs (Necessary till now, will be migrated to frontend page) ===== - -# Voice Service Config -APPID=app_id -TOKEN=token - -# ===== Non-essential Configs (Modify if you know what you are doing) ===== - -CLUSTER=volcano_tts -VOICE_TYPE=zh_male_jieshuonansheng_mars_bigtts -SPEED_RATIO=1.3 - -# ===== Proxy Configuration (Optional) ===== - -# HTTP_PROXY=http://proxy-server:port -# HTTPS_PROXY=http://proxy-server:port -# NO_PROXY=localhost,127.0.0.1 - -# ===== Backend Configuration (No need to modify at all) ===== - -# Model Path Config -CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32 -NLTK_DATA=/opt/models/nltk_data - -# Elasticsearch Service -ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200 -ELASTIC_PASSWORD=nexent@2025 - -# Elasticsearch Memory Configuration -ES_JAVA_OPTS="-Xms2g -Xmx2g" - -# Elasticsearch Disk Watermark Configuration -ES_DISK_WATERMARK_LOW=85% -ES_DISK_WATERMARK_HIGH=90% -ES_DISK_WATERMARK_FLOOD_STAGE=95% - -# Main Services -# Config service (port 5010) - Main API service for config operations -CONFIG_SERVICE_URL=http://nexent-config:5010 -ELASTICSEARCH_SERVICE=http://nexent-config:5010/api - -# Runtime service (port 5014) - Runtime execution service for agent operations -RUNTIME_SERVICE_URL=http://nexent-runtime:5014 - -# MCP service (port 5011) - MCP protocol service -NEXENT_MCP_SERVER=http://nexent-mcp:5011 -MCP_MANAGEMENT_API=http://nexent-mcp:5015 - -# Data process service (port 5012) - Data processing service -DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api - -# Northbound service (port 5013) - Northbound API service -NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api - -# Postgres Config -POSTGRES_HOST=nexent-postgresql -POSTGRES_USER=root -NEXENT_POSTGRES_PASSWORD=nexent@4321 -POSTGRES_DB=nexent -POSTGRES_PORT=5432 - -# Minio Config -MINIO_ENDPOINT=http://nexent-minio:9000 -MINIO_ROOT_USER=nexent -MINIO_ROOT_PASSWORD=nexent@4321 -MINIO_REGION=cn-north-1 -MINIO_DEFAULT_BUCKET=nexent - -# Redis Config -REDIS_URL=redis://redis:6379/0 -REDIS_BACKEND_URL=redis://redis:6379/1 - -# Model Engine Config -MODEL_ENGINE_ENABLED=false - -# Supabase Config -DASHBOARD_USERNAME=supabase -DASHBOARD_PASSWORD=Huawei123 - -# Supabase db Config -SUPABASE_POSTGRES_PASSWORD=Huawei123 -SUPABASE_POSTGRES_HOST=db -SUPABASE_POSTGRES_DB=supabase -SUPABASE_POSTGRES_PORT=5436 - -# Supabase Auth Config -SITE_URL=http://localhost:3011 -SUPABASE_URL=http://supabase-kong-mini:8000 -API_EXTERNAL_URL=http://supabase-kong-mini:8000 -DISABLE_SIGNUP=false -JWT_EXPIRY=3600 -DEBUG_JWT_EXPIRE_SECONDS=0 - -# Supabase Configuration -ENABLE_EMAIL_SIGNUP=true -ENABLE_EMAIL_AUTOCONFIRM=true -ENABLE_ANONYMOUS_USERS=false - -# Supabase Phone Config -ENABLE_PHONE_SIGNUP=false -ENABLE_PHONE_AUTOCONFIRM=false - -MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify" -MAILER_URLPATHS_INVITE="/auth/v1/verify" -MAILER_URLPATHS_RECOVERY="/auth/v1/verify" -MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify" - -INVITE_CODE=nexent2025 - -# Terminal Tool SSH Key Path -SSH_PRIVATE_KEY_PATH=/path/to/openssh-server/ssh-keys/openssh_server_key - -# ===== Data Processing Service Configuration ===== - -# Redis Port -REDIS_PORT=6379 - -# Flower Monitoring -FLOWER_PORT=5555 - -# Ray Configuration -RAY_ACTOR_NUM_CPUS=2 -RAY_DASHBOARD_PORT=8265 -RAY_DASHBOARD_HOST=0.0.0.0 -RAY_NUM_CPUS=4 -RAY_OBJECT_STORE_MEMORY_GB=0.25 -RAY_TEMP_DIR=/tmp/ray -RAY_LOG_LEVEL=INFO - -# Service Control Flags -DISABLE_RAY_DASHBOARD=true -DISABLE_CELERY_FLOWER=true -DOCKER_ENVIRONMENT=false -ENABLE_UPLOAD_IMAGE=false - -# Celery Configuration -CELERY_WORKER_PREFETCH_MULTIPLIER=1 -CELERY_TASK_TIME_LIMIT=3600 -ELASTICSEARCH_REQUEST_TIMEOUT=30 - -# Worker Configuration -QUEUES=process_q,forward_q -WORKER_NAME= -WORKER_CONCURRENCY=4 - -# Skills Configuration -SKILLS_PATH=/mnt/nexent/skills - -# Telemetry and Monitoring Configuration -ENABLE_TELEMETRY=false -SERVICE_NAME=nexent-backend -JAEGER_ENDPOINT=http://localhost:14268/api/traces -PROMETHEUS_PORT=8000 -TELEMETRY_SAMPLE_RATE=1.0 -LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0 -LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0 - -# Market Backend Address -MARKET_BACKEND=http://60.204.251.153:8010 -DEPLOYMENT_VERSION="speed" -# Root dir -ROOT_DIR="/c/Users/18270/nexent-data" -TERMINAL_MOUNT_DIR="/opt/terminal" -SSH_USERNAME="root" -SSH_PASSWORD="731215" -NEXENT_MCP_DOCKER_IMAGE="ccr.ccs.tencentyun.com/nexent-hub/nexent-mcp:v2.0.1" -MINIO_ACCESS_KEY="72c31cb5b521511cea652723" -MINIO_SECRET_KEY="m5gcSuKzZnp84CqmG7z5VKnd2C+H5U3PSr7eoJeygmI=" diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 909592345..7dbba48d7 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -34,11 +34,17 @@ const TOOLS_REQUIRING_EMBEDDING = [ "knowledge_base_search", ]; -// Tool types that require VLM model -const TOOLS_REQUIRING_VLM = [ +// Tool types that require the image understanding model +const TOOLS_REQUIRING_IMAGE_UNDERSTANDING = [ "analyze_image", ]; +// Tool types that require the video understanding model +const TOOLS_REQUIRING_VIDEO_UNDERSTANDING = [ + "analyze_audio", + "analyze_video", +]; + function getToolKbType( toolName: string ): "knowledge_base_search" | "dify_search" | "datamate_search" | "idata_search" | "haotian_search" | null { @@ -53,9 +59,18 @@ function getToolKbType( /** * Check if a tool requires VLM model but VLM is not available */ -function isToolDisabledDueToVlm(toolName: string, vlmAvailable: boolean): boolean { - if (!TOOLS_REQUIRING_VLM.includes(toolName)) return false; - return !vlmAvailable; +function isToolDisabledDueToVlm( + toolName: string, + imageUnderstandingAvailable: boolean, + videoUnderstandingAvailable: boolean +): boolean { + if (TOOLS_REQUIRING_IMAGE_UNDERSTANDING.includes(toolName)) { + return !imageUnderstandingAvailable; + } + if (TOOLS_REQUIRING_VIDEO_UNDERSTANDING.includes(toolName)) { + return !videoUnderstandingAvailable; + } + return false; } /** @@ -102,7 +117,11 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable, + isEmbeddingAvailable, + } = useConfig(); // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -362,7 +381,11 @@ export default function ToolManagement({ const isSelected = originalSelectedToolIdsSet.has( tool.id ); - const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); + const isDisabledDueToVlm = isToolDisabledDueToVlm( + tool.name, + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable + ); const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding @@ -467,7 +490,11 @@ export default function ToolManagement({ > {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); - const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); + const isDisabledDueToVlm = isToolDisabledDueToVlm( + tool.name, + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable + ); const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding diff --git a/frontend/app/[locale]/chat/components/chatAttachment.tsx b/frontend/app/[locale]/chat/components/chatAttachment.tsx index 5c9da8ec9..d12e939cd 100644 --- a/frontend/app/[locale]/chat/components/chatAttachment.tsx +++ b/frontend/app/[locale]/chat/components/chatAttachment.tsx @@ -87,6 +87,14 @@ const getFileIcon = (name: string, contentType?: string) => { return ; } + // Audio and video files are uploaded as regular attachments for multimodal tools. + if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) { + return ; + } + if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) { + return ; + } + // Compressed file if (chatConfig.fileIcons.compressed.includes(extension)) { return ; @@ -230,4 +238,4 @@ export function ChatAttachment({ )} ); -} \ No newline at end of file +} diff --git a/frontend/app/[locale]/chat/components/chatInput.tsx b/frontend/app/[locale]/chat/components/chatInput.tsx index 8de0d17eb..bcfc86f6b 100644 --- a/frontend/app/[locale]/chat/components/chatInput.tsx +++ b/frontend/app/[locale]/chat/components/chatInput.tsx @@ -96,10 +96,24 @@ const getFileIcon = (file: File) => { return ; } + if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) { + return ; + } + + if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) { + return ; + } + // Default file icon return ; }; +const isSupportedMediaFile = (extension: string, fileType: string) => + fileType.startsWith("audio/") || + fileType.startsWith("video/") || + chatConfig.audioExtensions.includes(extension) || + chatConfig.videoExtensions.includes(extension); + // File limit constants from config const MAX_FILE_COUNT = chatConfig.maxFileCount; const MAX_FILE_SIZE = chatConfig.maxFileSize; @@ -617,8 +631,9 @@ export function ChatInput({ chatConfig.supportedTextExtensions.includes(extension) || file.type === "text/csv" || file.type === "text/plain"; + const isMedia = isSupportedMediaFile(extension, file.type); - if (isImage || isDocument || isSupportedTextFile) { + if (isImage || isDocument || isSupportedTextFile || isMedia) { // Create a preview URL for images const previewUrl = isImage ? URL.createObjectURL(file) : undefined; @@ -899,7 +914,7 @@ export function ChatInput({ id="file-upload-regular" className="hidden" onChange={handleFileUpload} - accept={`image/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`} + accept={`image/*,audio/*,video/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`} multiple /> @@ -1026,8 +1041,9 @@ export function ChatInput({ chatConfig.supportedTextExtensions.includes(extension) || fileType === "text/csv" || fileType === "text/plain"; + const isMedia = isSupportedMediaFile(extension, fileType); - return !(isImage || isDocument || isSupportedTextFile); + return !(isImage || isDocument || isSupportedTextFile || isMedia); }); // Regular mode, keep the original rendering logic diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index 6e0de48b5..0f3c99715 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -38,7 +38,7 @@ import { extractAssistantMsgFromResponse, } from "@/lib/chatMessageExtractor"; -import { Layout } from "antd"; +import { Layout, message } from "antd"; import log from "@/lib/logger"; const stepIdCounter = { current: 0 }; @@ -268,9 +268,23 @@ export function ChatInterface() { // Use preprocessing function to upload attachments const uploadResult = await uploadAttachments(attachments, t); + if (uploadResult.error) { + message.error(`${t("chatPreprocess.fileUploadFailed")} ${uploadResult.error}`); + setIsLoading(false); + return; + } uploadedFileUrls = uploadResult.uploadedFileUrls; objectNames = uploadResult.objectNames; // Get object name mapping presignedUrls = uploadResult.presignedUrls; // Get presigned URLs for external access + + const missingUploads = attachments.filter( + (attachment) => !uploadedFileUrls[attachment.file.name] || !objectNames[attachment.file.name] + ); + if (missingUploads.length > 0) { + message.error(`${t("chatPreprocess.fileUploadFailed")} ${missingUploads.map((item) => item.file.name).join(", ")}`); + setIsLoading(false); + return; + } } // Use preprocessing function to create message attachments diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 11391c133..85737c251 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -66,6 +66,23 @@ const DEFAULT_FORM_STATE = { accessToken: "", }; +const resolveConnectivityModelType = (type: ModelType): ModelType => + type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3 + ? (MODEL_TYPES.VLM as ModelType) + : type; + +const resolveConfigKey = (type: ModelType): string => + type; + +const isVlmConfigType = (type: ModelType): boolean => + type === MODEL_TYPES.VLM || type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3; + +const emptyModelConfig = { + modelName: "", + displayName: "", + apiConfig: { apiKey: "", modelUrl: "" }, +}; + // Connectivity status type comes from utils // Helper function to translate error messages from backend @@ -196,7 +213,7 @@ export const ModelAddDialog = ({ }: ModelAddDialogProps) => { const { t } = useTranslation(); const { message } = App.useApp(); - const { updateModelConfig, saveConfig } = useConfig(); + const { modelConfig: currentModelConfig, updateModelConfig, saveConfig } = useConfig(); // Parse backend error message and return i18n key with params const parseModelError = ( @@ -475,7 +492,7 @@ export const ModelAddDialog = ({ const modelType = form.type === MODEL_TYPES.EMBEDDING && form.isMultimodal ? (MODEL_TYPES.MULTI_EMBEDDING as ModelType) - : form.type; + : resolveConnectivityModelType(form.type); let connectivity = false; @@ -615,6 +632,32 @@ export const ModelAddDialog = ({ }); } + if (isVlmConfigType(form.type) && enabledModels.length > 0) { + const selectedModel = enabledModels[0]; + const selectedDisplayName = selectedModel.displayName || selectedModel.id || ""; + const configKey = resolveConfigKey(form.type); + const vlmConfigUpdate: any = { + [configKey]: { + modelName: selectedModel.id || selectedModel.model_name || "", + displayName: selectedDisplayName, + apiConfig: { + apiKey: form.apiKey, + modelUrl: "", + }, + }, + }; + for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) { + if ( + key !== configKey && + currentModelConfig?.[key]?.displayName === selectedDisplayName + ) { + vlmConfigUpdate[key] = emptyModelConfig; + } + } + updateModelConfig(vlmConfigUpdate); + await persistModelConfig(); + } + // Reset form state and close dialog on success resetForm(); handleClose(); @@ -772,6 +815,7 @@ export const ModelAddDialog = ({ // Update the local storage according to the model type let configUpdate: any = {}; + const configKey = resolveConfigKey(form.type); switch (modelType) { case MODEL_TYPES.LLM: @@ -784,7 +828,17 @@ export const ModelAddDialog = ({ configUpdate = { multiEmbedding: modelConfig }; break; case MODEL_TYPES.VLM: - configUpdate = { vlm: modelConfig }; + case MODEL_TYPES.VLM2: + case MODEL_TYPES.VLM3: + configUpdate = { [configKey]: modelConfig }; + for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) { + if ( + key !== configKey && + currentModelConfig?.[key]?.displayName === modelConfig.displayName + ) { + configUpdate[key] = emptyModelConfig; + } + } break; case MODEL_TYPES.RERANK: configUpdate = { rerank: modelConfig }; @@ -926,7 +980,15 @@ export const ModelAddDialog = ({ - + + + diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index ad3cf0391..96eebc2e4 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -101,6 +101,8 @@ export const ModelDeleteDialog = ({ border: "border-purple-100", }; case MODEL_TYPES.VLM: + case MODEL_TYPES.VLM2: + case MODEL_TYPES.VLM3: return { bg: "bg-yellow-50", text: "text-yellow-600", @@ -143,6 +145,8 @@ export const ModelDeleteDialog = ({ case MODEL_TYPES.TTS: return "🔊"; case MODEL_TYPES.VLM: + case MODEL_TYPES.VLM2: + case MODEL_TYPES.VLM3: return "👁️"; default: return "⚙️"; @@ -167,6 +171,10 @@ export const ModelDeleteDialog = ({ return t("model.type.tts"); case MODEL_TYPES.VLM: return t("model.type.vlm"); + case MODEL_TYPES.VLM2: + return `${t("model.type.vlm")}2`; + case MODEL_TYPES.VLM3: + return `${t("model.type.vlm")}3`; default: return t("model.type.unknown"); } @@ -346,7 +354,10 @@ export const ModelDeleteDialog = ({ if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl; } if (type === MODEL_TYPES.VLM) { - const cfgUrl = modelConfig?.vlm?.apiConfig?.modelUrl; + const cfgUrl = + modelConfig?.vlm?.apiConfig?.modelUrl || + modelConfig?.vlm2?.apiConfig?.modelUrl || + modelConfig?.vlm3?.apiConfig?.modelUrl; if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl; } if (type === MODEL_TYPES.LLM) { @@ -503,6 +514,22 @@ export const ModelDeleteDialog = ({ }; } + if (modelConfig.vlm2?.displayName === displayName) { + configUpdates.vlm2 = { + modelName: "", + displayName: "", + apiConfig: { apiKey: "", modelUrl: "" }, + }; + } + + if (modelConfig.vlm3?.displayName === displayName) { + configUpdates.vlm3 = { + modelName: "", + displayName: "", + apiConfig: { apiKey: "", modelUrl: "" }, + }; + } + if (modelConfig.stt.displayName === displayName) { configUpdates.stt = { modelName: "", displayName: "" }; } @@ -1028,6 +1055,8 @@ export const ModelDeleteDialog = ({ MODEL_TYPES.MULTI_EMBEDDING, MODEL_TYPES.RERANK, MODEL_TYPES.VLM, + MODEL_TYPES.VLM2, + MODEL_TYPES.VLM3, MODEL_TYPES.STT, MODEL_TYPES.TTS, ] as ModelType[] diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 3114c5535..cdac265f5 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -87,6 +87,10 @@ export const ModelEditDialog = ({ form.type === MODEL_TYPES.EMBEDDING || form.type === MODEL_TYPES.MULTI_EMBEDDING; const isRerankModel = form.type === MODEL_TYPES.RERANK; + const connectivityModelType = + form.type === MODEL_TYPES.VLM2 || form.type === MODEL_TYPES.VLM3 + ? (MODEL_TYPES.VLM as ModelType) + : form.type; const isFormValid = () => { return form.name.trim() !== "" && form.url.trim() !== ""; @@ -106,11 +110,9 @@ export const ModelEditDialog = ({ }); try { - const modelType = form.type as ModelType; - const config = { modelName: form.name, - modelType: modelType, + modelType: connectivityModelType, baseUrl: form.url, apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, maxTokens: @@ -205,6 +207,8 @@ export const ModelEditDialog = ({ embedding: MODEL_TYPES.EMBEDDING, multi_embedding: MODEL_TYPES.MULTI_EMBEDDING, vlm: MODEL_TYPES.VLM, + vlm2: MODEL_TYPES.VLM2, + vlm3: MODEL_TYPES.VLM3, rerank: MODEL_TYPES.RERANK, tts: MODEL_TYPES.TTS, stt: MODEL_TYPES.STT, @@ -481,4 +485,4 @@ export const ProviderConfigEditDialog = ({ ) -} \ No newline at end of file +} diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index 07eee5c06..36fcdbb31 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -56,7 +56,11 @@ const getModelData = (t: any) => ({ }, multimodal: { title: t("modelConfig.category.multimodal"), - options: [{ id: MODEL_TYPES.VLM, name: t("modelConfig.option.vlmModel") }], + options: [ + { id: MODEL_TYPES.VLM, name: t("modelConfig.option.imageUnderstandingModel") }, + { id: MODEL_TYPES.VLM2, name: t("modelConfig.option.imageGenerationModel") }, + { id: MODEL_TYPES.VLM3, name: t("modelConfig.option.videoUnderstandingModel") }, + ], }, voice: { title: t("modelConfig.category.voice"), @@ -142,7 +146,7 @@ export const ModelConfigSection = forwardRef< llm: { main: "" }, embedding: { embedding: "", multi_embedding: "" }, reranker: { reranker: "" }, - multimodal: { vlm: "" }, + multimodal: { vlm: "", vlm2: "", vlm3: "" }, voice: { tts: "", stt: "" }, }); @@ -284,11 +288,23 @@ export const ModelConfigSection = forwardRef< : true; const vlm = modelConfig.vlm.displayName; + const vlm2 = modelConfig.vlm2?.displayName || ""; + const vlm3 = modelConfig.vlm3?.displayName || ""; const vlmExists = vlm ? allModels.some( (m) => m.displayName === vlm && m.type === MODEL_TYPES.VLM ) : true; + const vlm2Exists = vlm2 + ? allModels.some( + (m) => m.displayName === vlm2 && m.type === MODEL_TYPES.VLM2 + ) + : true; + const vlm3Exists = vlm3 + ? allModels.some( + (m) => m.displayName === vlm3 && m.type === MODEL_TYPES.VLM3 + ) + : true; const stt = modelConfig.stt.displayName; const sttExists = stt @@ -318,6 +334,8 @@ export const ModelConfigSection = forwardRef< }, multimodal: { vlm: vlmExists ? vlm : "", + vlm2: vlm2Exists ? vlm2 : "", + vlm3: vlm3Exists ? vlm3 : "", }, voice: { tts: ttsExists ? tts : "", @@ -363,6 +381,14 @@ export const ModelConfigSection = forwardRef< configUpdates.vlm = { modelName: "", displayName: "" }; } + if (!vlm2Exists && vlm2) { + configUpdates.vlm2 = { modelName: "", displayName: "" }; + } + + if (!vlm3Exists && vlm3) { + configUpdates.vlm3 = { modelName: "", displayName: "" }; + } + if (!sttExists && stt) { configUpdates.stt = { modelName: "", displayName: "" }; } @@ -385,6 +411,8 @@ export const ModelConfigSection = forwardRef< !!modelConfig.multiEmbedding.modelName || !!modelConfig.rerank.modelName || !!modelConfig.vlm.modelName || + !!modelConfig.vlm2?.modelName || + !!modelConfig.vlm3?.modelName || !!modelConfig.tts.modelName || !!modelConfig.stt.modelName; @@ -441,11 +469,13 @@ export const ModelConfigSection = forwardRef< const hasEmbedding = !!modelConfig.embedding.modelName; const hasReranker = !!modelConfig.rerank.modelName; const hasVlm = !!modelConfig.vlm.modelName; + const hasVlm2 = !!modelConfig.vlm2?.modelName; + const hasVlm3 = !!modelConfig.vlm3?.modelName; const hasTts = !!modelConfig.tts.modelName; const hasStt = !!modelConfig.stt.modelName; hasSelectedModels = - hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasTts || hasStt; + hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasVlm2 || hasVlm3 || hasTts || hasStt; if (hasSelectedModels) { currentSelectedModels.llm.main = modelConfig.llm.modelName; @@ -455,6 +485,8 @@ export const ModelConfigSection = forwardRef< modelConfig.multiEmbedding.modelName || ""; currentSelectedModels.reranker.reranker = modelConfig.rerank.modelName; currentSelectedModels.multimodal.vlm = modelConfig.vlm.modelName; + currentSelectedModels.multimodal.vlm2 = modelConfig.vlm2?.modelName || ""; + currentSelectedModels.multimodal.vlm3 = modelConfig.vlm3?.modelName || ""; currentSelectedModels.voice.tts = modelConfig.tts.modelName; currentSelectedModels.voice.stt = modelConfig.stt.modelName; } else { @@ -492,7 +524,7 @@ export const ModelConfigSection = forwardRef< } else if (category === "reranker") { modelType = MODEL_TYPES.RERANK; } else if (category === "multimodal") { - modelType = MODEL_TYPES.VLM; + modelType = optionId as ModelType; } else if (category === MODEL_TYPES.EMBEDDING) { modelType = optionId === MODEL_TYPES.MULTI_EMBEDDING @@ -654,7 +686,7 @@ export const ModelConfigSection = forwardRef< } else if (category === "reranker") { modelType = MODEL_TYPES.RERANK; } else if (category === "multimodal") { - modelType = MODEL_TYPES.VLM; + modelType = option as ModelType; } else if (category === MODEL_TYPES.EMBEDDING) { modelType = option === MODEL_TYPES.MULTI_EMBEDDING @@ -679,7 +711,7 @@ export const ModelConfigSection = forwardRef< ) { configKey = "multiEmbedding"; } else if (category === "multimodal") { - configKey = MODEL_TYPES.VLM; + configKey = option; } else if (category === "reranker") { configKey = MODEL_TYPES.RERANK; } else if (category === "voice" && option === "tts") { @@ -1005,7 +1037,7 @@ export const ModelConfigSection = forwardRef< ? MODEL_TYPES.TTS : MODEL_TYPES.STT : key === "multimodal" - ? MODEL_TYPES.VLM + ? (option.id as ModelType) : key === MODEL_TYPES.EMBEDDING && option.id === MODEL_TYPES.MULTI_EMBEDDING ? MODEL_TYPES.MULTI_EMBEDDING diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts index fc0dbe6d5..27b3b887d 100644 --- a/frontend/const/chatConfig.ts +++ b/frontend/const/chatConfig.ts @@ -38,6 +38,12 @@ export const chatConfig = { // Supported document file extensions documentExtensions: ["pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "epub", "html", "xml"], + + // Supported audio file extensions + audioExtensions: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"], + + // Supported video file extensions + videoExtensions: ["mp4", "mov", "m4v", "avi", "mkv", "webm", "wmv", "flv"], // Supported text document extensions supportedTextExtensions: ["md", "markdown", "txt", "csv", "json"], @@ -73,6 +79,12 @@ export const chatConfig = { // Compressed file compressed: ["zip", "rar", "7z", "tar", "gz"], + + // Audio files + audio: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"], + + // Video files + video: ["mp4", "mov", "m4v", "avi", "mkv", "wmv", "flv"], }, // File preview type constants @@ -148,4 +160,4 @@ export const MESSAGE_ROLES = { USER: "user" as const, ASSISTANT: "assistant" as const, SYSTEM: "system" as const, -} as const; \ No newline at end of file +} as const; diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts index 9bdc5a4a8..b7762ace0 100644 --- a/frontend/const/modelConfig.ts +++ b/frontend/const/modelConfig.ts @@ -7,6 +7,8 @@ export const MODEL_TYPES = { STT: "stt", TTS: "tts", VLM: "vlm", + VLM2: "vlm2", + VLM3: "vlm3", } as const; // Model source constants diff --git a/frontend/hooks/model/useDashscopeModelList.ts b/frontend/hooks/model/useDashscopeModelList.ts index b44348fe5..5d1035e8a 100644 --- a/frontend/hooks/model/useDashscopeModelList.ts +++ b/frontend/hooks/model/useDashscopeModelList.ts @@ -39,7 +39,9 @@ export const useDashscopeModelList = ({ const modelType = form.type === "embedding" && form.isMultimodal ? ("multi_embedding" as ModelType) - : form.type; + : form.type === "vlm2" || form.type === "vlm3" + ? ("vlm" as ModelType) + : form.type; try { // Use manage interface if tenantId is provided (for super admin) diff --git a/frontend/hooks/model/useTokenponyModelList.ts b/frontend/hooks/model/useTokenponyModelList.ts index 0a7e23581..0c502a404 100644 --- a/frontend/hooks/model/useTokenponyModelList.ts +++ b/frontend/hooks/model/useTokenponyModelList.ts @@ -39,7 +39,9 @@ export const useTokenPonyModelList = ({ const modelType = form.type === "embedding" && form.isMultimodal ? ("multi_embedding" as ModelType) - : form.type; + : form.type === "vlm2" || form.type === "vlm3" + ? ("vlm" as ModelType) + : form.type; try { // Use manage interface if tenantId is provided (for super admin) diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts index 8d4c4ccea..4d095e681 100644 --- a/frontend/hooks/useConfig.ts +++ b/frontend/hooks/useConfig.ts @@ -76,6 +76,22 @@ const defaultConfig: GlobalConfig = { modelUrl: "", }, }, + vlm2: { + modelName: "", + displayName: "", + apiConfig: { + apiKey: "", + modelUrl: "", + }, + }, + vlm3: { + modelName: "", + displayName: "", + apiConfig: { + apiKey: "", + modelUrl: "", + }, + }, stt: { modelName: "", displayName: "", @@ -161,6 +177,8 @@ function transformBackendToFrontend(backendConfig: any): GlobalConfig { ), rerank: transformModelEntry(backendConfig.models.rerank), vlm: transformModelEntry(backendConfig.models.vlm), + vlm2: transformModelEntry(backendConfig.models.vlm2), + vlm3: transformModelEntry(backendConfig.models.vlm3), stt: transformVoiceModelEntry(backendConfig.models.stt), tts: transformVoiceModelEntry(backendConfig.models.tts), } @@ -195,7 +213,10 @@ function loadConfigFromStorage(): GlobalConfig | null { if (storedModelConfig) { try { - mergedConfig.models = JSON.parse(storedModelConfig); + mergedConfig.models = deepMerge( + mergedConfig.models, + JSON.parse(storedModelConfig) + ); } catch (error) { log.error("Failed to parse model config:", error); } @@ -285,7 +306,24 @@ export function useConfig() { const config: GlobalConfig = (query.data as GlobalConfig | undefined) ?? defaultConfig; // Whether config has selected a VLM model - const isVlmAvailable = !!(config?.models?.vlm?.modelName || config?.models?.vlm?.displayName); + const isVlmAvailable = !!( + config?.models?.vlm?.modelName || + config?.models?.vlm?.displayName || + config?.models?.vlm2?.modelName || + config?.models?.vlm2?.displayName || + config?.models?.vlm3?.modelName || + config?.models?.vlm3?.displayName + ); + + const isImageUnderstandingAvailable = !!( + config?.models?.vlm?.modelName || + config?.models?.vlm?.displayName + ); + + const isVideoUnderstandingAvailable = !!( + config?.models?.vlm3?.modelName || + config?.models?.vlm3?.displayName + ); // Whether config has selected an Embedding model const isEmbeddingAvailable = !!(config?.models?.embedding?.modelName || config?.models?.embedding?.displayName); @@ -368,6 +406,8 @@ export function useConfig() { appConfig: config?.app, modelConfig: config?.models, isVlmAvailable, + isImageUnderstandingAvailable, + isVideoUnderstandingAvailable, isEmbeddingAvailable, defaultLlmModelName, updateAppConfig, diff --git a/frontend/lib/chat/chatAttachmentUtils.ts b/frontend/lib/chat/chatAttachmentUtils.ts index c85615b4e..bff686ca1 100644 --- a/frontend/lib/chat/chatAttachmentUtils.ts +++ b/frontend/lib/chat/chatAttachmentUtils.ts @@ -69,6 +69,19 @@ export const uploadAttachments = async ( }); } + const failedResults = uploadResult.results.filter((result) => !result.success); + if (failedResults.length > 0 || uploadResult.success_count < attachments.length) { + const failedMessage = failedResults + .map((result) => `${result.file_name || "file"}: ${result.error || "Upload failed"}`) + .join("; "); + return { + uploadedFileUrls, + objectNames, + presignedUrls, + error: failedMessage || "Upload failed", + }; + } + return { uploadedFileUrls, objectNames, presignedUrls }; } catch (error) { log.error(t("chatPreprocess.fileUploadFailed"), error); diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 22c17c2ca..2b22c3156 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -1,4 +1,4 @@ -{ +{ "assistant.name": "Nexent", "mainPage.layout.title": "Nexent | AI Agents", @@ -811,6 +811,9 @@ "model.type.llm": "Large Language Model", "model.type.embedding": "Embedding Model", "model.type.vlm": "Vision Language Model", + "model.type.imageUnderstanding": "Image Understanding Model", + "model.type.imageGeneration": "Image Generation Model", + "model.type.videoUnderstanding": "Video Understanding Model", "model.type.rerank": "Rerank Model", "model.type.stt": "Speech-to-Text Model", "model.type.tts": "Text-to-Speech Model", @@ -879,6 +882,9 @@ "modelConfig.option.multiEmbeddingModel": "Multimodal Embedding Model", "modelConfig.option.rerankerModel": "Reranker Model", "modelConfig.option.vlmModel": "Vision Language Model", + "modelConfig.option.imageUnderstandingModel": "Image Understanding Model", + "modelConfig.option.imageGenerationModel": "Image Generation Model", + "modelConfig.option.videoUnderstandingModel": "Video Understanding Model", "modelConfig.option.ttsModel": "Text-to-Speech Model", "modelConfig.option.sttModel": "Speech-to-Text Model", "modelConfig.error.loadList": "Failed to load model list:", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 1cc83a802..b0e9d7ef1 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -1,4 +1,4 @@ -{ +{ "assistant.name": "Nexent", "mainPage.layout.title": "Nexent | 智能问答", @@ -811,6 +811,9 @@ "model.type.llm": "大语言模型", "model.type.embedding": "向量模型", "model.type.vlm": "视觉语言模型", + "model.type.imageUnderstanding": "图片理解模型", + "model.type.imageGeneration": "图片生成模型", + "model.type.videoUnderstanding": "视频理解模型", "model.type.rerank": "重排模型", "model.type.stt": "语音识别模型", "model.type.tts": "语音合成模型", @@ -880,6 +883,9 @@ "modelConfig.option.multiEmbeddingModel": "多模态向量模型", "modelConfig.option.rerankerModel": "重排模型", "modelConfig.option.vlmModel": "视觉语言模型", + "modelConfig.option.imageUnderstandingModel": "图片理解模型", + "modelConfig.option.imageGenerationModel": "图片生成模型", + "modelConfig.option.videoUnderstandingModel": "视频理解模型", "modelConfig.option.ttsModel": "语音合成模型", "modelConfig.option.sttModel": "语音识别模型", "modelConfig.error.loadList": "加载模型列表失败:", diff --git a/frontend/tailwind.config.ts b/frontend/tailwind.config.js similarity index 100% rename from frontend/tailwind.config.ts rename to frontend/tailwind.config.js diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index a9f918d71..1a1df2b98 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -31,6 +31,8 @@ export type ModelType = | "stt" | "tts" | "vlm" + | "vlm2" + | "vlm3" | "multi_embedding"; // Model option interface @@ -99,6 +101,8 @@ export interface ModelConfig { multiEmbedding: SingleModelConfig; rerank: SingleModelConfig; vlm: SingleModelConfig; + vlm2: SingleModelConfig; + vlm3: SingleModelConfig; stt: STTModelConfig; tts: TTSModelConfig; } diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 023c8348e..c0dc83efc 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -115,7 +115,7 @@ def create_local_tool(self, tool_config: ToolConfig): data_process_service_url=tool_config.metadata.get("data_process_service_url", []), validate_url_access=validate_url_access, **params) - elif class_name == "AnalyzeImageTool": + elif class_name in ["AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool"]: # Extract validate_url_access from metadata if it's callable validate_url_access = tool_config.metadata.get("validate_url_access") if tool_config.metadata else None if validate_url_access is not None and not callable(validate_url_access): @@ -493,4 +493,4 @@ def _val_width(vals, extra_val=None): # Optional: write to local file with open("nexent_context_metrics.log", "a", encoding="utf-8") as f: - f.write("\n".join(lines) + "\n") \ No newline at end of file + f.write("\n".join(lines) + "\n") diff --git a/sdk/nexent/core/models/openai_vlm.py b/sdk/nexent/core/models/openai_vlm.py index 1babb0057..cbc7388d6 100644 --- a/sdk/nexent/core/models/openai_vlm.py +++ b/sdk/nexent/core/models/openai_vlm.py @@ -126,6 +126,47 @@ def prepare_image_message(self, image_input: Union[str, BinaryIO], system_prompt return messages + def prepare_media_message( + self, + media_input: Union[str, BinaryIO], + media_type: str, + content_type: str, + system_prompt: str) -> List[Dict[str, Any]]: + """ + Prepare an OpenAI-compatible multimodal message for audio or video inputs. + + Args: + media_input: Media file path or file stream object. + media_type: Either "audio" or "video". + content_type: MIME type for the data URL. + system_prompt: System prompt. + + Returns: + List[Dict[str, Any]]: Prepared message list. + """ + if media_type not in ("audio", "video"): + raise ValueError(f"Unsupported media type: {media_type}") + + base64_media = self.encode_image(media_input) + media_url_key = f"{media_type}_url" + media_config: Dict[str, Any] = {"url": f"data:{content_type};base64,{base64_media}"} + if media_type == "video": + media_config.update({"detail": "high", "max_frames": 16, "fps": 1}) + + messages = [ + { + "role": "user", + "content": [ + { + "type": media_url_key, + media_url_key: media_config + }, + {"type": "text", "text": system_prompt} + ] + } + ] + return messages + def analyze_image(self, image_input: Union[str, BinaryIO], system_prompt: str = "Please describe this picture concisely and carefully, within 200 words.", stream: bool = True, **kwargs) -> ChatMessage: @@ -144,3 +185,23 @@ def analyze_image(self, image_input: Union[str, BinaryIO], messages = self.prepare_image_message(image_input, system_prompt) # Call __call__ explicitly so instance-level mocks work in tests. return self.__call__(messages=messages, **kwargs) + + def analyze_audio( + self, + audio_input: Union[str, BinaryIO], + system_prompt: str = "Please analyze this audio carefully.", + content_type: str = "audio/mpeg", + **kwargs) -> ChatMessage: + """Analyze audio content using the configured multimodal model.""" + messages = self.prepare_media_message(audio_input, "audio", content_type, system_prompt) + return self.__call__(messages=messages, **kwargs) + + def analyze_video( + self, + video_input: Union[str, BinaryIO], + system_prompt: str = "Please analyze this video carefully.", + content_type: str = "video/mp4", + **kwargs) -> ChatMessage: + """Analyze video content using the configured multimodal model.""" + messages = self.prepare_media_message(video_input, "video", content_type, system_prompt) + return self.__call__(messages=messages, **kwargs) diff --git a/sdk/nexent/core/prompts/analyze_audio_en.yaml b/sdk/nexent/core/prompts/analyze_audio_en.yaml new file mode 100644 index 000000000..eee0bb060 --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_audio_en.yaml @@ -0,0 +1,13 @@ +# Audio Understanding Prompt Templates + +system_prompt: |- + The user has asked a question: {{ query }}. Please analyze this audio from the perspective of answering this question, within 300 words. + + **Audio Analysis Requirements:** + 1. Focus on speech, sound events, tone, timing, and other audio content relevant to the user's question + 2. If speech is present, summarize or transcribe the key spoken content when possible + 3. Keep the answer concise and grounded in observable audio evidence + 4. Avoid guessing identities or facts that cannot be inferred from the audio + +user_prompt: | + Please listen to this audio and describe it from the perspective of answering the user's question. diff --git a/sdk/nexent/core/prompts/analyze_audio_zh.yaml b/sdk/nexent/core/prompts/analyze_audio_zh.yaml new file mode 100644 index 000000000..ae6f1fa0d --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_audio_zh.yaml @@ -0,0 +1,13 @@ +# 音频理解 Prompt 模板 + +system_prompt: |- + 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段音频,控制在 300 字以内。 + + **音频分析要求:** + 1. 关注与用户问题相关的语音、声音事件、语气、节奏和其他音频内容 + 2. 如果包含人声,请尽可能总结或转写关键口语内容 + 3. 回答要简洁,并基于音频中可观察到的信息 + 4. 不要猜测无法从音频中判断的身份或事实 + +user_prompt: | + 请仔细聆听这段音频,并从回答用户问题的角度进行描述。 diff --git a/sdk/nexent/core/prompts/analyze_video_en.yaml b/sdk/nexent/core/prompts/analyze_video_en.yaml new file mode 100644 index 000000000..7834ca7f3 --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_video_en.yaml @@ -0,0 +1,13 @@ +# Video Understanding Prompt Templates + +system_prompt: |- + The user has asked a question: {{ query }}. Please analyze this video from the perspective of answering this question, within 300 words. + + **Video Analysis Requirements:** + 1. Focus on scenes, actions, objects, people, visible text, and temporal changes relevant to the user's question + 2. Mention important audio cues only when they help answer the question + 3. Keep the answer concise, structured, and grounded in visible or audible evidence + 4. Avoid over-interpreting intent or facts that cannot be inferred from the video + +user_prompt: | + Please watch this video and describe it from the perspective of answering the user's question. diff --git a/sdk/nexent/core/prompts/analyze_video_zh.yaml b/sdk/nexent/core/prompts/analyze_video_zh.yaml new file mode 100644 index 000000000..e83a1676d --- /dev/null +++ b/sdk/nexent/core/prompts/analyze_video_zh.yaml @@ -0,0 +1,13 @@ +# 视频理解 Prompt 模板 + +system_prompt: |- + 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段视频,控制在 300 字以内。 + + **视频分析要求:** + 1. 关注与用户问题相关的场景、动作、物体、人物、可见文字和时间变化 + 2. 只有在有助于回答问题时,才补充重要的音频线索 + 3. 回答要简洁、有条理,并基于视频中可见或可听的信息 + 4. 不要过度推断无法从视频中判断的意图或事实 + +user_prompt: | + 请仔细观看这段视频,并从回答用户问题的角度进行描述。 diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py index 851690f16..a640cb5ff 100644 --- a/sdk/nexent/core/tools/__init__.py +++ b/sdk/nexent/core/tools/__init__.py @@ -19,6 +19,8 @@ from .terminal_tool import TerminalTool from .analyze_text_file_tool import AnalyzeTextFileTool from .analyze_image_tool import AnalyzeImageTool +from .analyze_audio_tool import AnalyzeAudioTool +from .analyze_video_tool import AnalyzeVideoTool from .run_skill_script_tool import run_skill_script from .read_skill_md_tool import read_skill_md from .read_skill_config_tool import read_skill_config @@ -47,6 +49,8 @@ "TerminalTool", "AnalyzeTextFileTool", "AnalyzeImageTool", + "AnalyzeAudioTool", + "AnalyzeVideoTool", "run_skill_script", "read_skill_md", "read_skill_config" diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py new file mode 100644 index 000000000..c7509a6c2 --- /dev/null +++ b/sdk/nexent/core/tools/analyze_audio_tool.py @@ -0,0 +1,169 @@ +""" +Analyze Audio Tool + +Analyze audio using the configured video understanding model. +Supports audio from S3, HTTP, and HTTPS URLs. +""" + +import logging +from io import BytesIO +from typing import List + +from jinja2 import StrictUndefined, Template +from pydantic import Field +from smolagents.tools import Tool + +from ...core.models import OpenAIVLModel +from ...core.utils.observer import MessageObserver, ProcessType +from ...core.utils.prompt_template_utils import get_prompt_template +from ...core.utils.tools_common_message import ToolCategory, ToolSign +from ...multi_modal.load_save_object import LoadSaveObjectManager +from ...multi_modal.utils import detect_content_type_from_bytes +from ...storage import MinIOStorageClient + +logger = logging.getLogger("analyze_audio_tool") + + +class AnalyzeAudioTool(Tool): + """Tool for understanding and analyzing audio using the video understanding model.""" + + name = "analyze_audio" + description = ( + "This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n" + "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "HTTP, and HTTPS URLs.\n" + "Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query." + ) + description_zh = ( + "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。" + "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + ) + + inputs = { + "audio_urls_list": { + "type": "array", + "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + }, + "query": { + "type": "string", + "description": "User's question to guide the audio analysis", + "description_zh": "用户用于指导音频分析的问题" + } + } + + init_param_descriptions = { + "observer": {"description": "Message observer"}, + "vlm_model": {"description": "The video understanding model to use"}, + "storage_client": {"description": "Storage client for downloading files"}, + "validate_url_access": { + "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" + } + } + output_type = "array" + category = ToolCategory.MULTIMODAL.value + tool_sign = ToolSign.MULTIMODAL_OPERATION.value + + def __init__( + self, + observer: MessageObserver = Field( + description="Message observer", + default=None, + exclude=True), + vlm_model: OpenAIVLModel = Field( + description="The video understanding model to use", + default=None, + exclude=True), + storage_client: MinIOStorageClient = Field( + description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.", + default=None, + exclude=True), + validate_url_access: callable = Field( + description="Callback function to validate URL access permissions", + default=None, + exclude=True) + ): + super().__init__() + self.observer = observer + self.vlm_model = vlm_model + self.storage_client = storage_client + self._is_chinese = bool(observer and observer.lang == "zh") + + validate_callback = None + if validate_url_access is not None and callable(validate_url_access): + validate_callback = validate_url_access + self.mm = LoadSaveObjectManager( + storage_client=self.storage_client, + validate_url_access=validate_callback + ) + self.forward = self.mm.load_object( + input_names=["audio_urls_list"])(self._forward_impl) + + self.running_prompt_zh = "正在分析音频..." + self.running_prompt_en = "Analyzing audio..." + + def _validate_audio_capable_model(self) -> None: + """Fail early for SiliconFlow models that are known not to accept audio input.""" + client_kwargs = getattr(self.vlm_model, "client_kwargs", {}) or {} + base_url = client_kwargs.get("base_url", "") if isinstance(client_kwargs, dict) else "" + model_id = str(getattr(self.vlm_model, "model_id", "") or "") + + if "siliconflow" in str(base_url).lower() and model_id and "omni" not in model_id.lower(): + raise ValueError( + "The selected video understanding model does not support audio input on SiliconFlow. " + "Please choose a Qwen3-Omni model for analyze_audio." + ) + + def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: + """Analyze audio files and return one result per audio input.""" + if self.vlm_model is None: + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + logger.error(error_msg) + raise Exception(error_msg) + self._validate_audio_capable_model() + + if self.observer: + running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en + self.observer.add_message("", ProcessType.TOOL, running_prompt) + + if audio_urls_list is None: + raise ValueError("audio_urls cannot be None") + if not isinstance(audio_urls_list, list): + raise ValueError("audio_urls must be a list of bytes") + if not audio_urls_list: + raise ValueError("audio_urls must contain at least one audio file") + + language = self.observer.lang if self.observer else "en" + prompts = get_prompt_template( + template_type='analyze_audio', language=language) + system_prompt = Template( + prompts['system_prompt'], undefined=StrictUndefined).render({'query': query}) + + try: + analysis_results: List[str] = [] + for index, audio_bytes in enumerate(audio_urls_list, start=1): + logger.info(f"Analyzing audio #{index}, query: {query}") + content_type = detect_content_type_from_bytes(audio_bytes) + if not content_type.startswith("audio/"): + content_type = "audio/mpeg" + audio_stream = BytesIO(audio_bytes) + try: + response = self.vlm_model.analyze_audio( + audio_input=audio_stream, + system_prompt=system_prompt, + content_type=content_type + ) + except Exception as e: + error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" + error_msg_en = f"Failed to analyze audio {index}: {str(e)}. Please check if the video understanding model is configured correctly." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + raise Exception(error_msg) + + analysis_results.append(response.content) + + return analysis_results + except Exception as e: + logger.error(f"Error analyzing audio: {str(e)}", exc_info=True) + raise Exception(f"Error analyzing audio: {str(e)}") diff --git a/sdk/nexent/core/tools/analyze_image_tool.py b/sdk/nexent/core/tools/analyze_image_tool.py index 3851a896b..f7640a9dc 100644 --- a/sdk/nexent/core/tools/analyze_image_tool.py +++ b/sdk/nexent/core/tools/analyze_image_tool.py @@ -24,17 +24,17 @@ class AnalyzeImageTool(Tool): - """Tool for understanding and analyzing image using a visual language model""" + """Tool for understanding and analyzing images using the image understanding model.""" name = "analyze_image" description = ( - "This tool uses a visual language model to understand images based on your query and then returns a description of the image.\n" + "This tool uses the configured image understanding model to understand images based on your query and then returns a description of the image.\n" "It is used to understand and analyze multiple images, with image sources supporting S3 URLs (s3://bucket/key or /bucket/key), " "HTTP, and HTTPS URLs.\n" "Use this tool when you want to retrieve information contained in an image and provide the image's URL and your query." ) - description_zh = "使用视觉语言模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + description_zh = "使用图片理解模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" inputs = { "image_urls_list": { @@ -54,7 +54,7 @@ class AnalyzeImageTool(Tool): "description": "Message observer" }, "vlm_model": { - "description": "The VLM model to use" + "description": "The image understanding model to use" }, "storage_client": { "description": "Storage client for downloading files" @@ -74,7 +74,7 @@ def __init__( default=None, exclude=True), vlm_model: OpenAIVLModel = Field( - description="The VLM model to use", + description="The image understanding model to use", default=None, exclude=True), storage_client: MinIOStorageClient = Field( @@ -130,10 +130,10 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]: Raises: Exception: If the image cannot be downloaded or analyzed. """ - # Check if VLM model is available + # Check if the image understanding model is available. if self.vlm_model is None: - error_msg_zh = "视觉语言模型(VLM)未配置,请联系管理员配置VLM模型后重试" - error_msg_en = "Vision Language Model (VLM) is not configured. Please contact your administrator to configure the VLM model and try again." + error_msg_zh = "图片理解模型未配置,请联系管理员配置图片理解模型后重试" + error_msg_en = "Image understanding model is not configured. Please contact your administrator to configure the image understanding model and try again." error_msg = error_msg_zh if self._is_chinese else error_msg_en logger.error(error_msg) raise Exception(error_msg) @@ -170,8 +170,8 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]: system_prompt=system_prompt ) except Exception as e: - error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查VLM模型配置是否正确。" - error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the VLM model is configured correctly." + error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查图片理解模型配置是否正确。" + error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the image understanding model is configured correctly." error_msg = error_msg_zh if self._is_chinese else error_msg_en raise Exception(error_msg) diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py new file mode 100644 index 000000000..3dc033551 --- /dev/null +++ b/sdk/nexent/core/tools/analyze_video_tool.py @@ -0,0 +1,156 @@ +""" +Analyze Video Tool + +Analyze videos using the configured video understanding model. +Supports videos from S3, HTTP, and HTTPS URLs. +""" + +import logging +from io import BytesIO +from typing import List + +from jinja2 import StrictUndefined, Template +from pydantic import Field +from smolagents.tools import Tool + +from ...core.models import OpenAIVLModel +from ...core.utils.observer import MessageObserver, ProcessType +from ...core.utils.prompt_template_utils import get_prompt_template +from ...core.utils.tools_common_message import ToolCategory, ToolSign +from ...multi_modal.load_save_object import LoadSaveObjectManager +from ...multi_modal.utils import detect_content_type_from_bytes +from ...storage import MinIOStorageClient + +logger = logging.getLogger("analyze_video_tool") + + +class AnalyzeVideoTool(Tool): + """Tool for understanding and analyzing videos using the video understanding model.""" + + name = "analyze_video" + description = ( + "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n" + "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "HTTP, and HTTPS URLs.\n" + "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query." + ) + description_zh = ( + "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。" + "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + ) + + inputs = { + "video_urls_list": { + "type": "array", + "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + }, + "query": { + "type": "string", + "description": "User's question to guide the video analysis", + "description_zh": "用户用于指导视频分析的问题" + } + } + + init_param_descriptions = { + "observer": {"description": "Message observer"}, + "vlm_model": {"description": "The video understanding model to use"}, + "storage_client": {"description": "Storage client for downloading files"}, + "validate_url_access": { + "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" + } + } + output_type = "array" + category = ToolCategory.MULTIMODAL.value + tool_sign = ToolSign.MULTIMODAL_OPERATION.value + + def __init__( + self, + observer: MessageObserver = Field( + description="Message observer", + default=None, + exclude=True), + vlm_model: OpenAIVLModel = Field( + description="The video understanding model to use", + default=None, + exclude=True), + storage_client: MinIOStorageClient = Field( + description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.", + default=None, + exclude=True), + validate_url_access: callable = Field( + description="Callback function to validate URL access permissions", + default=None, + exclude=True) + ): + super().__init__() + self.observer = observer + self.vlm_model = vlm_model + self.storage_client = storage_client + self._is_chinese = bool(observer and observer.lang == "zh") + + validate_callback = None + if validate_url_access is not None and callable(validate_url_access): + validate_callback = validate_url_access + self.mm = LoadSaveObjectManager( + storage_client=self.storage_client, + validate_url_access=validate_callback + ) + self.forward = self.mm.load_object( + input_names=["video_urls_list"])(self._forward_impl) + + self.running_prompt_zh = "正在分析视频..." + self.running_prompt_en = "Analyzing video..." + + def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: + """Analyze videos and return one result per video input.""" + if self.vlm_model is None: + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + logger.error(error_msg) + raise Exception(error_msg) + + if self.observer: + running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en + self.observer.add_message("", ProcessType.TOOL, running_prompt) + + if video_urls_list is None: + raise ValueError("video_urls cannot be None") + if not isinstance(video_urls_list, list): + raise ValueError("video_urls must be a list of bytes") + if not video_urls_list: + raise ValueError("video_urls must contain at least one video") + + language = self.observer.lang if self.observer else "en" + prompts = get_prompt_template( + template_type='analyze_video', language=language) + system_prompt = Template( + prompts['system_prompt'], undefined=StrictUndefined).render({'query': query}) + + try: + analysis_results: List[str] = [] + for index, video_bytes in enumerate(video_urls_list, start=1): + logger.info(f"Analyzing video #{index}, query: {query}") + content_type = detect_content_type_from_bytes(video_bytes) + if not content_type.startswith("video/"): + content_type = "video/mp4" + video_stream = BytesIO(video_bytes) + try: + response = self.vlm_model.analyze_video( + video_input=video_stream, + system_prompt=system_prompt, + content_type=content_type + ) + except Exception as e: + error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" + error_msg_en = f"Failed to analyze video {index}: {str(e)}. Please check if the video understanding model is configured correctly." + error_msg = error_msg_zh if self._is_chinese else error_msg_en + raise Exception(error_msg) + + analysis_results.append(response.content) + + return analysis_results + except Exception as e: + logger.error(f"Error analyzing video: {str(e)}", exc_info=True) + raise Exception(f"Error analyzing video: {str(e)}") diff --git a/sdk/nexent/core/utils/prompt_template_utils.py b/sdk/nexent/core/utils/prompt_template_utils.py index ad06e9119..24b273876 100644 --- a/sdk/nexent/core/utils/prompt_template_utils.py +++ b/sdk/nexent/core/utils/prompt_template_utils.py @@ -17,6 +17,14 @@ LANGUAGE["ZH"]: 'core/prompts/analyze_image_zh.yaml', LANGUAGE["EN"]: 'core/prompts/analyze_image_en.yaml' }, + 'analyze_audio': { + LANGUAGE["ZH"]: 'core/prompts/analyze_audio_zh.yaml', + LANGUAGE["EN"]: 'core/prompts/analyze_audio_en.yaml' + }, + 'analyze_video': { + LANGUAGE["ZH"]: 'core/prompts/analyze_video_zh.yaml', + LANGUAGE["EN"]: 'core/prompts/analyze_video_en.yaml' + }, 'analyze_file': { LANGUAGE["ZH"]: 'core/prompts/analyze_file_zh.yaml', LANGUAGE["EN"]: 'core/prompts/analyze_file_en.yaml' @@ -30,6 +38,8 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw Args: template_type: Template type, supports the following values: - 'analyze_image': Analyze image template + - 'analyze_audio': Analyze audio template + - 'analyze_video': Analyze video template - 'analyze_file': Analyze file template (for text files) language: Language code ('zh' or 'en') **kwargs: Additional parameters, for agent type need to pass is_manager parameter @@ -52,4 +62,4 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw # Read and return template content with open(absolute_template_path, 'r', encoding='utf-8') as f: - return yaml.safe_load(f) \ No newline at end of file + return yaml.safe_load(f) diff --git a/sdk/nexent/multi_modal/utils.py b/sdk/nexent/multi_modal/utils.py index e118f6940..bcd6cdd35 100644 --- a/sdk/nexent/multi_modal/utils.py +++ b/sdk/nexent/multi_modal/utils.py @@ -34,10 +34,10 @@ def is_url(url: str) -> Optional[UrlType]: if url.startswith("https://"): return "https" - if url.startswith("s3://"): - bucket_path = url.replace("s3://", "", 1) + if url.startswith("s3://") or url.startswith("s3:/"): + bucket_path = url.replace("s3://", "", 1) if url.startswith("s3://") else url.replace("s3:/", "", 1).lstrip("/") bucket_object = bucket_path.split("/", 1) - if len(bucket_object) == 2 and all(bucket_object): + if len(bucket_object) == 2 and all(bucket_object) and ":" not in bucket_object[0]: return "s3" return None @@ -321,6 +321,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]: Supports formats: - s3://bucket/key + - s3:/bucket/key - /bucket/key (MinIO path format) Args: @@ -335,11 +336,16 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]: if not s3_url: raise ValueError("S3 URL cannot be empty") - if s3_url.startswith('s3://'): - parts = s3_url.replace('s3://', '').split('/', 1) + if s3_url.startswith('s3://') or s3_url.startswith('s3:/'): + normalized_url = ( + s3_url.replace('s3://', '', 1) + if s3_url.startswith('s3://') + else s3_url.replace('s3:/', '', 1).lstrip('/') + ) + parts = normalized_url.split('/', 1) if len(parts) == 2: bucket, object_name = parts - if not bucket or not object_name: + if not bucket or not object_name or ":" in bucket: raise ValueError(f"Invalid s3:// URL format: {s3_url}") return bucket, object_name raise ValueError(f"Invalid s3:// URL format: {s3_url}") @@ -351,4 +357,4 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]: return bucket, object_name raise ValueError(f"Invalid path format: {s3_url}") - raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...") \ No newline at end of file + raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...") diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index 5817fbe27..20340f2ea 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -30,6 +30,21 @@ class ValidationError(Exception): pass +class MCPConnectionError(Exception): + """Mock MCPConnectionError for testing.""" + pass + + +class NotFoundException(Exception): + """Mock NotFoundException for testing.""" + pass + + +class ToolExecutionException(Exception): + """Mock ToolExecutionException for testing.""" + pass + + consts_model_module = types.ModuleType("consts.model") consts_model_module.HistoryItem = HistoryItem sys.modules["consts.model"] = consts_model_module @@ -37,6 +52,9 @@ class ValidationError(Exception): # Mock consts.exceptions module with ValidationError consts_exceptions_module = types.ModuleType("consts.exceptions") consts_exceptions_module.ValidationError = ValidationError +consts_exceptions_module.MCPConnectionError = MCPConnectionError +consts_exceptions_module.NotFoundException = NotFoundException +consts_exceptions_module.ToolExecutionException = ToolExecutionException sys.modules["consts.exceptions"] = consts_exceptions_module # Also add model and exceptions to consts module attributes @@ -165,7 +183,9 @@ def _create_stub_module(name: str, **attrs): services_module = _create_stub_module("services") sys.modules['services'] = services_module sys.modules['services.image_service'] = _create_stub_module( - "services.image_service", get_vlm_model=MagicMock(return_value="stub_vlm") + "services.image_service", + get_vlm_model=MagicMock(return_value="stub_vlm"), + get_video_understanding_model=MagicMock(return_value="stub_video_vlm"), ) sys.modules['services.memory_config_service'] = MagicMock() # Extend services hierarchy with additional stubs @@ -250,6 +270,7 @@ def _create_stub_module(name: str, **attrs): _extract_url_from_card, _build_external_agent_config, _get_external_a2a_agents, + _build_internal_s3_url, _format_minio_files_for_content, _convert_history_with_minio_files, ) @@ -727,6 +748,48 @@ async def test_create_tool_config_list_with_analyze_image_tool(self): assert "validate_url_access" in mock_tool_instance.metadata assert callable(mock_tool_instance.metadata["validate_url_access"]) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "class_name,tool_name", + [ + ("AnalyzeAudioTool", "analyze_audio"), + ("AnalyzeVideoTool", "analyze_video"), + ], + ) + async def test_create_tool_config_list_with_audio_video_tools(self, class_name, tool_name): + """Ensure audio/video tools receive video understanding model metadata.""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = class_name + mock_tool_config.return_value = mock_tool_instance + + with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_video_understanding_model') as mock_get_video_model, \ + patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock): + + mock_search_tools.return_value = [ + { + "class_name": class_name, + "name": tool_name, + "description": "Analyze media tool", + "inputs": "string", + "output_type": "string", + "params": [{"name": "prompt", "default": "describe"}], + "source": "local", + "usage": None + } + ] + mock_get_video_model.return_value = "mock_video_model" + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + assert len(result) == 1 + assert result[0] is mock_tool_instance + mock_get_video_model.assert_called_once_with(tenant_id="tenant_1") + assert mock_tool_instance.metadata["vlm_model"] == "mock_video_model" + assert "storage_client" in mock_tool_instance.metadata + assert callable(mock_tool_instance.metadata["validate_url_access"]) + @pytest.mark.asyncio async def test_create_tool_config_list_with_analyze_text_file_tool(self): """Ensure AnalyzeTextFileTool receives text-specific metadata.""" @@ -3297,6 +3360,26 @@ async def test_create_agent_run_info_is_need_auth_true_includes_token(self): class TestJoinMinioFileDescriptionToQuery: """Tests for the join_minio_file_description_to_query function""" + def test_build_internal_s3_url_prefers_object_name(self): + file = { + "object_name": "attachments/user/image.png", + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + + result = _build_internal_s3_url(file) + + assert result.endswith("/attachments/user/image.png") + assert result.startswith("s3://") + + def test_build_internal_s3_url_rejects_blob_preview_url(self): + file = { + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + + assert _build_internal_s3_url(file) == "" + @pytest.mark.asyncio async def test_join_minio_file_description_to_query_with_files(self): """Test case with file descriptions""" @@ -3345,6 +3428,40 @@ async def test_join_minio_file_description_to_query_no_descriptions(self): assert result == "test query" + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_prefers_object_name_over_blob_url(self): + """Uploaded images should be exposed to internal tools through MinIO, not browser blob URLs.""" + minio_files = [ + { + "object_name": "attachments/user/image.png", + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + ] + query = "describe the image" + + result = await join_minio_file_description_to_query(minio_files, query) + + assert "blob:http" not in result + assert "File name: image.png" in result + assert "attachments/user/image.png" in result + assert "S3 URL: s3://" in result + + @pytest.mark.asyncio + async def test_join_minio_file_description_to_query_skips_blob_only_file(self): + """Browser-only preview URLs cannot be used by internal tools.""" + minio_files = [ + { + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + ] + query = "describe the image" + + result = await join_minio_file_description_to_query(minio_files, query) + + assert result == query + @pytest.mark.asyncio async def test_join_minio_file_description_to_query_deduplication_current(self): """Test that duplicate files in current message are de-duplicated by URL""" @@ -4455,6 +4572,21 @@ def test_format_minio_files_for_content_single_file_without_presigned_url(self): result = _format_minio_files_for_content(minio_files) assert result == "\n[Attached files]:\n - file.txt: s3:/bucket/file.txt" + def test_format_minio_files_for_content_uses_object_name_for_blob_url(self): + """Use uploaded object_name instead of browser-only blob preview URL.""" + minio_files = [ + { + "object_name": "attachments/user/image.png", + "url": "blob:http://localhost:3000/preview", + "name": "image.png", + } + ] + + result = _format_minio_files_for_content(minio_files) + + assert "blob:http" not in result + assert "attachments/user/image.png" in result + def test_format_minio_files_for_content_multiple_files(self): """Test case for multiple files""" minio_files = [ diff --git a/test/backend/services/providers/test_silicon_provider.py b/test/backend/services/providers/test_silicon_provider.py index b947040c3..c596643b2 100644 --- a/test/backend/services/providers/test_silicon_provider.py +++ b/test/backend/services/providers/test_silicon_provider.py @@ -66,7 +66,13 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): mock_response.status_code = 200 mock_response.json.return_value = { "data": [ - {"id": "gpt-4v", "name": "GPT-4 Vision"}, + {"id": "deepseek-ai/DeepSeek-R1", "name": "DeepSeek R1"}, + {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "name": "Qwen2.5 VL"}, + {"id": "OpenGVLab/InternVL2-26B", "name": "InternVL2 26B"}, + {"id": "Pro/moonshotai/Kimi-K2.6", "name": "Kimi K2.6"}, + {"id": "Pro/moonshotai/Kimi-K2.5", "name": "Kimi K2.5"}, + {"id": "Qwen/Qwen3.6-27B", "name": "Qwen3.6 27B"}, + {"id": "Qwen/Qwen3.6-35B-A3B", "name": "Qwen3.6 35B A3B"}, ] } mock_response.raise_for_status = MagicMock() @@ -95,10 +101,66 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) - assert len(result) == 1 - assert result[0]["id"] == "gpt-4v" - assert result[0]["model_type"] == "vlm" - assert result[0]["model_tag"] == "chat" + assert [model["id"] for model in result] == [ + "Qwen/Qwen2.5-VL-72B-Instruct", + "OpenGVLab/InternVL2-26B", + "Pro/moonshotai/Kimi-K2.6", + "Pro/moonshotai/Kimi-K2.5", + "Qwen/Qwen3.6-27B", + "Qwen/Qwen3.6-35B-A3B", + ] + assert all(model["model_type"] == "vlm" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + + @pytest.mark.asyncio + async def test_get_models_vlm3_only_returns_omni_models(self, mocker: MockFixture): + """Test that SiliconFlow video understanding models are restricted to Omni models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + {"id": "Qwen/Qwen3-VL-32B-Instruct", "name": "Qwen3 VL"}, + {"id": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "name": "Qwen3 Omni Instruct"}, + {"id": "Qwen/Qwen3-Omni-30B-A3B-Thinking", "name": "Qwen3 Omni Thinking"}, + {"id": "Qwen/Qwen3-Omni-30B-A3B-Captioner", "name": "Qwen3 Omni Captioner"}, + {"id": "zai-org/GLM-4.5V", "name": "GLM 4.5V"}, + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.silicon_provider.httpx.AsyncClient", + return_value=mock_cm + ) + mocker.patch( + "backend.services.providers.silicon_provider.SILICON_GET_URL", + "https://api.siliconflow.com/v1/models" + ) + + provider = SiliconModelProvider() + provider_config = { + "model_type": "vlm3", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert [model["id"] for model in result] == [ + "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "Qwen/Qwen3-Omni-30B-A3B-Thinking", + "Qwen/Qwen3-Omni-30B-A3B-Captioner", + ] + assert all(model["model_type"] == "vlm3" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + call_args = mock_client.get.call_args + assert "sub_type=chat" in call_args[0][0] @pytest.mark.asyncio async def test_get_models_embedding_success(self, mocker: MockFixture): diff --git a/test/backend/services/test_image_service.py b/test/backend/services/test_image_service.py index 1de8d49fd..34f24568c 100644 --- a/test/backend/services/test_image_service.py +++ b/test/backend/services/test_image_service.py @@ -13,10 +13,17 @@ helpers_env = bootstrap_test_env() helpers_env["mock_const"].DATA_PROCESS_SERVICE = "http://mock-data-process-service" -helpers_env["mock_const"].MODEL_CONFIG_MAPPING = {"vlm": "vlm_model_config"} +helpers_env["mock_const"].MODEL_CONFIG_MAPPING = { + "vlm": "vlm_model_config", + "vlm3": "video_model_config", +} mock_const = helpers_env["mock_const"] -from services.image_service import get_vlm_model, proxy_image_impl +from services.image_service import get_image_understanding_model, get_video_understanding_model, get_vlm_model, proxy_image_impl + +image_service_module = sys.modules[get_vlm_model.__module__] +if "services" in sys.modules: + setattr(sys.modules["services"], "image_service", image_service_module) # Sample test data test_url = "https://example.com/image.jpg" @@ -50,7 +57,7 @@ async def test_proxy_image_impl_success(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -85,7 +92,7 @@ async def test_proxy_image_impl_remote_error(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -118,7 +125,7 @@ async def test_proxy_image_impl_500_error(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -146,7 +153,7 @@ async def test_proxy_image_impl_connection_exception(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function - should raise the exception @@ -178,7 +185,7 @@ async def test_proxy_image_impl_with_special_chars(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -213,7 +220,7 @@ async def test_proxy_image_impl_json_parse_error(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function - should raise the exception @@ -253,7 +260,7 @@ async def test_proxy_image_impl_different_status_codes(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function @@ -289,7 +296,7 @@ async def test_proxy_image_impl_url_encoding(): mock_client_session.__aenter__.return_value = mock_session # Patch the ClientSession - with patch('services.image_service.aiohttp.ClientSession') as mock_session_class: + with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class: mock_session_class.return_value = mock_client_session # Test the function with encoded URL @@ -305,10 +312,10 @@ async def test_proxy_image_impl_url_encoding(): assert f"url={encoded_url}" in called_url -@patch('services.image_service.OpenAIVLModel') -@patch('services.image_service.MessageObserver') -@patch('services.image_service.get_model_name_from_config') -@patch('services.image_service.tenant_config_manager') +@patch.object(image_service_module, 'OpenAIVLModel') +@patch.object(image_service_module, 'MessageObserver') +@patch.object(image_service_module, 'get_model_name_from_config') +@patch.object(image_service_module, 'tenant_config_manager') def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model): """Ensure get_vlm_model builds OpenAIVLModel with tenant config.""" mock_config = { @@ -324,7 +331,7 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, result = get_vlm_model("tenant-1") mock_tenant_config_manager.get_model_config.assert_called_once_with( - key=mock_const.MODEL_CONFIG_MAPPING["vlm"], + key="vlm_model_config", tenant_id="tenant-1" ) mock_message_observer.assert_called_once_with() @@ -342,10 +349,10 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, assert result == mock_model_instance -@patch('services.image_service.OpenAIVLModel') -@patch('services.image_service.MessageObserver') -@patch('services.image_service.get_model_name_from_config') -@patch('services.image_service.tenant_config_manager') +@patch.object(image_service_module, 'OpenAIVLModel') +@patch.object(image_service_module, 'MessageObserver') +@patch.object(image_service_module, 'get_model_name_from_config') +@patch.object(image_service_module, 'tenant_config_manager') def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model): """Return None when tenant config is None.""" mock_tenant_config_manager.get_model_config.return_value = None @@ -359,3 +366,40 @@ def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_mod # OpenAIVLModel should not be called when config is None mock_openai_vl_model.assert_not_called() assert result is None + + +@patch.object(image_service_module, 'get_vlm_model') +def test_get_image_understanding_model_uses_first_multimodal_slot(mock_get_vlm_model): + """Ensure the image understanding alias keeps using the first multimodal slot.""" + mock_get_vlm_model.return_value = "image-understanding-model" + + result = get_image_understanding_model("tenant-1") + + mock_get_vlm_model.assert_called_once_with(tenant_id="tenant-1") + assert result == "image-understanding-model" + + +@patch.object(image_service_module, 'OpenAIVLModel') +@patch.object(image_service_module, 'MessageObserver') +@patch.object(image_service_module, 'get_model_name_from_config') +@patch.object(image_service_module, 'tenant_config_manager') +def test_get_video_understanding_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model): + """Ensure video understanding tools use the third multimodal model slot.""" + mock_config = { + "base_url": "https://mock-video-api", + "api_key": "secret", + "model_name": "video-model" + } + mock_tenant_config_manager.get_model_config.return_value = mock_config + mock_get_model_name.return_value = "video-model" + mock_model_instance = MagicMock() + mock_openai_vl_model.return_value = mock_model_instance + + result = get_video_understanding_model("tenant-1") + + mock_tenant_config_manager.get_model_config.assert_called_once_with( + key="video_model_config", + tenant_id="tenant-1" + ) + mock_openai_vl_model.assert_called_once() + assert result == mock_model_instance diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 6e504e90a..bd56cce40 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -660,17 +660,11 @@ async def test_batch_create_models_for_tenant_flow(): existing = [ {"model_id": "del-id", "model_repo": "silicon", "model_name": "delete"}, - {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep"}, + {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep", "max_tokens": 1024}, ] - def get_by_display(display_name, tenant_id): - if display_name == "silicon/keep": - return {"model_id": "keep-id", "max_tokens": 1024} - return None - with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing) as mock_get_existing, \ mock.patch.object(svc, "delete_model_record") as mock_delete, \ - mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \ mock.patch.object(svc, "update_model_record") as mock_update, \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \ mock.patch.object(svc, "create_model_record") as mock_create: @@ -679,13 +673,35 @@ def get_by_display(display_name, tenant_id): mock_get_existing.assert_called_once_with("t1", "silicon", "llm") mock_delete.assert_called_once_with("del-id", "u1", "t1") - mock_get_by_display.assert_any_call("silicon/keep", "t1") mock_update.assert_called_once_with( "keep-id", {"max_tokens": 4096}, "u1") mock_prep.assert_awaited() mock_create.assert_called_once() +@pytest.mark.asyncio +async def test_batch_create_models_uses_requested_type_for_each_model(): + svc = import_svc() + + batch_payload = { + "provider": "silicon", + "type": "vlm", + "models": [ + {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "model_type": "llm", "max_tokens": 4096}, + ], + "api_key": "k", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \ + mock.patch.object(svc, "create_model_record"): + + await svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + prepared_model = mock_prep.call_args.kwargs["model"] + assert prepared_model["model_type"] == "vlm" + + @pytest.mark.asyncio async def test_batch_create_models_max_tokens_update(): """Test batch_create_models updates max_tokens when display_name exists and max_tokens changed (covers lines 160->173, 168->171)""" @@ -702,22 +718,16 @@ async def test_batch_create_models_max_tokens_update(): "api_key": "k", } - def get_by_display(display_name, tenant_id): - if display_name == "silicon/model1": - # Different from new value - return {"model_id": "id1", "max_tokens": 4096} - elif display_name == "silicon/model2": - return {"model_id": "id2", "max_tokens": 4096} # Same as new value - elif display_name == "silicon/model3": - # Existing has value, new is None - return {"model_id": "id3", "max_tokens": 2048} - return None + existing = [ + {"model_id": "id1", "model_repo": "silicon", "model_name": "model1", "max_tokens": 4096}, + {"model_id": "id2", "model_repo": "silicon", "model_name": "model2", "max_tokens": 4096}, + {"model_id": "id3", "model_repo": "silicon", "model_name": "model3", "max_tokens": 2048}, + ] - with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \ + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing), \ mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \ mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \ - mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \ mock.patch.object(svc, "update_model_record") as mock_update, \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 3cbdcee2b..44bcb7681 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -15,6 +15,54 @@ minio_client_mock = MagicMock() sys.modules['boto3'] = boto3_mock +fastmcp_mock = types.ModuleType('fastmcp') +fastmcp_mock.__path__ = [] + + +class MockFastMcpClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def is_connected(self): + return True + + async def call_tool(self, *args, **kwargs): + return MagicMock() + + +class MockSSETransport: + def __init__(self, *args, **kwargs): + pass + + +class MockStreamableHttpTransport: + def __init__(self, *args, **kwargs): + pass + + +fastmcp_mock.Client = MockFastMcpClient +fastmcp_client_mock = types.ModuleType('fastmcp.client') +fastmcp_client_mock.__path__ = [] +fastmcp_transports_mock = types.ModuleType('fastmcp.client.transports') +fastmcp_transports_mock.SSETransport = MockSSETransport +fastmcp_transports_mock.StreamableHttpTransport = MockStreamableHttpTransport +sys.modules['fastmcp'] = fastmcp_mock +sys.modules['fastmcp.client'] = fastmcp_client_mock +sys.modules['fastmcp.client.transports'] = fastmcp_transports_mock + +mcpadapt_mock = types.ModuleType('mcpadapt') +mcpadapt_mock.__path__ = [] +mcpadapt_smolagents_adapter_mock = types.ModuleType('mcpadapt.smolagents_adapter') +mcpadapt_smolagents_adapter_mock._sanitize_function_name = lambda name: name +sys.modules['mcpadapt'] = mcpadapt_mock +sys.modules['mcpadapt.smolagents_adapter'] = mcpadapt_smolagents_adapter_mock + # Patch smolagents and its sub-modules before importing consts.model to avoid ImportError mock_smolagents = MagicMock() sys.modules['smolagents'] = mock_smolagents @@ -323,7 +371,7 @@ def validate(self): 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(), 'ElasticSearchService': MagicMock()}, 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()}, - 'image_service': {'get_vlm_model': MagicMock()} + 'image_service': {'get_vlm_model': MagicMock(), 'get_video_understanding_model': MagicMock()} } for service_name, attrs in services_modules.items(): service_module = types.ModuleType(f'services.{service_name}') @@ -354,6 +402,7 @@ def validate(self): patch('services.tenant_config_service.build_knowledge_name_mapping', MagicMock()).start() patch('services.image_service.get_vlm_model', MagicMock()).start() +patch('services.image_service.get_video_understanding_model', MagicMock()).start() patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start() patch('backend.services.tool_configuration_service.get_embedding_model_by_index_name', MagicMock()).start() @@ -2676,6 +2725,63 @@ def test_validate_local_tool_analyze_image_missing_user(self, mock_get_class): ) +class TestValidateLocalToolAnalyzeAudioVideo: + """Test cases for _validate_local_tool with analyze_audio/analyze_video tools.""" + + @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"]) + @patch('backend.services.tool_configuration_service.minio_client') + @patch('backend.services.tool_configuration_service.get_video_understanding_model') + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.inspect.signature') + def test_validate_local_tool_analyze_audio_video_success( + self, mock_signature, mock_get_class, mock_get_video_model, mock_minio_client, tool_name): + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = f"{tool_name} result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + mock_get_video_model.return_value = "mock_video_model" + + mock_sig = Mock() + mock_sig.parameters = {} + mock_signature.return_value = mock_sig + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + tool_name, + {"media": "bytes"}, + {"prompt": "describe"}, + "tenant1", + "user1" + ) + + assert result == f"{tool_name} result" + mock_get_video_model.assert_called_once_with(tenant_id="tenant1") + call_kwargs = mock_tool_class.call_args.kwargs + assert call_kwargs["vlm_model"] == "mock_video_model" + assert "storage_client" in call_kwargs + assert callable(call_kwargs["validate_url_access"]) + mock_tool_instance.forward.assert_called_once_with(media="bytes") + + @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"]) + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + def test_validate_local_tool_analyze_audio_video_missing_tenant(self, mock_get_class, tool_name): + mock_get_class.return_value = Mock() + + from backend.services.tool_configuration_service import _validate_local_tool + + with pytest.raises(ToolExecutionException, + match=f"Tenant ID and User ID are required for {tool_name} validation"): + _validate_local_tool( + tool_name, + {"media": "bytes"}, + {"prompt": "describe"}, + None, + "user1" + ) + + class TestValidateLocalToolDatamateSearchTool: """Test cases for _validate_local_tool function with datamate_search_tool""" diff --git a/test/common/test_mocks.py b/test/common/test_mocks.py index c87b52859..c57941780 100644 --- a/test/common/test_mocks.py +++ b/test/common/test_mocks.py @@ -112,6 +112,8 @@ def setup_common_mocks(): "multiEmbedding": "MULTI_EMBEDDING_ID", "rerank": "RERANK_ID", "vlm": "VLM_ID", + "vlm2": "VLM2_ID", + "vlm3": "VLM3_ID", "stt": "STT_ID", "tts": "TTS_ID" } diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 9853b9eca..2d588b62d 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -2429,6 +2429,52 @@ def test_create_local_tool_analyze_image(self, nexent_agent_instance): assert call_kwargs["param1"] == "value1" assert result == mock_tool_instance + @pytest.mark.parametrize( + "class_name,tool_name", + [ + ("AnalyzeAudioTool", "analyze_audio"), + ("AnalyzeVideoTool", "analyze_video"), + ], + ) + def test_create_local_tool_analyze_audio_video(self, nexent_agent_instance, class_name, tool_name): + """Test successful audio/video analysis tool creation.""" + mock_tool_class = MagicMock() + mock_tool_instance = MagicMock() + mock_tool_class.return_value = mock_tool_instance + + tool_config = ToolConfig( + class_name=class_name, + name=tool_name, + description="desc", + inputs="{}", + output_type="string", + params={"param1": "value1"}, + source="local", + metadata={ + "vlm_model": ["video-understanding-model"], + "storage_client": "storage" + } + ) + + original_value = nexent_agent.__dict__.get(class_name) + nexent_agent.__dict__[class_name] = mock_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__[class_name] = original_value + elif class_name in nexent_agent.__dict__: + del nexent_agent.__dict__[class_name] + + mock_tool_class.assert_called_once() + call_kwargs = mock_tool_class.call_args[1] + assert call_kwargs["observer"] == nexent_agent_instance.observer + assert call_kwargs["vlm_model"] == ["video-understanding-model"] + assert call_kwargs["storage_client"] == "storage" + assert call_kwargs["param1"] == "value1" + assert result == mock_tool_instance + def test_create_local_tool_analyze_text_file_with_validate_url_access_none(self, nexent_agent_instance): """Test AnalyzeTextFileTool creation with validate_url_access not in metadata (None).""" mock_tool_class = MagicMock() diff --git a/test/sdk/core/models/test_openai_vlm.py b/test/sdk/core/models/test_openai_vlm.py index f1db49380..4f7104290 100644 --- a/test/sdk/core/models/test_openai_vlm.py +++ b/test/sdk/core/models/test_openai_vlm.py @@ -62,7 +62,13 @@ def vl_model_instance(): """Return an OpenAIVLModel instance with minimal viable attributes for tests.""" observer = MagicMock() - model = ImportedOpenAIVLModel(observer=observer, ssl_verify=True) + model = ImportedOpenAIVLModel( + observer=observer, + model_id="dummy-model", + api_key="dummy-key", + api_base="https://example.test", + ssl_verify=True, + ) # Inject dummy attributes required by the method under test model.model_id = "dummy-model" @@ -321,3 +327,55 @@ def test_analyze_image_calls_prepare_image_message(vl_model_instance, tmp_path): # Verify prepare_image_message was called with correct arguments mock_prepare.assert_called_once_with(str(test_image), custom_prompt) + + +def test_prepare_media_message_audio(vl_model_instance): + audio_stream = MagicMock() + audio_stream.read.return_value = b"audio bytes" + + messages = vl_model_instance.prepare_media_message( + audio_stream, + media_type="audio", + content_type="audio/mpeg", + system_prompt="Listen carefully", + ) + + assert messages[0]["content"][0]["type"] == "audio_url" + assert messages[0]["content"][0]["audio_url"]["url"].startswith("data:audio/mpeg;base64,") + assert messages[0]["content"][1] == {"type": "text", "text": "Listen carefully"} + + +def test_prepare_media_message_video(vl_model_instance): + video_stream = MagicMock() + video_stream.read.return_value = b"video bytes" + + messages = vl_model_instance.prepare_media_message( + video_stream, + media_type="video", + content_type="video/mp4", + system_prompt="Watch carefully", + ) + + assert messages[0]["content"][0]["type"] == "video_url" + assert messages[0]["content"][0]["video_url"]["url"].startswith("data:video/mp4;base64,") + assert messages[0]["content"][0]["video_url"]["max_frames"] == 16 + assert messages[0]["content"][0]["video_url"]["fps"] == 1 + assert messages[0]["content"][1] == {"type": "text", "text": "Watch carefully"} + + +def test_analyze_audio_calls_prepare_media_message(vl_model_instance): + with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare: + vl_model_instance.__call__ = MagicMock(return_value=MagicMock()) + + vl_model_instance.analyze_audio("audio.mp3", system_prompt="Analyze", content_type="audio/mpeg") + + mock_prepare.assert_called_once_with("audio.mp3", "audio", "audio/mpeg", "Analyze") + + +def test_analyze_video_calls_prepare_media_message(vl_model_instance): + with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare: + vl_model_instance.__call__ = MagicMock(return_value=MagicMock()) + + vl_model_instance.analyze_video("video.mp4", system_prompt="Analyze", content_type="video/mp4") + + mock_prepare.assert_called_once_with("video.mp4", "video", "video/mp4", "Analyze") diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py new file mode 100644 index 000000000..94401b61d --- /dev/null +++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py @@ -0,0 +1,119 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from sdk.nexent.core.tools import analyze_audio_tool, analyze_video_tool +from sdk.nexent.core.tools.analyze_audio_tool import AnalyzeAudioTool +from sdk.nexent.core.tools.analyze_video_tool import AnalyzeVideoTool +from sdk.nexent.core.utils.observer import MessageObserver, ProcessType + + +@pytest.fixture +def mock_storage_client(): + class DummyStorage: + pass + + return DummyStorage() + + +@pytest.fixture +def mock_vlm_model(): + return MagicMock() + + +@pytest.fixture +def observer_en(): + observer = MagicMock(spec=MessageObserver) + observer.lang = "en" + return observer + + +def test_analyze_audio_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + calls = [] + + def _fake_get_prompt(template_type, language=None, **_): + calls.append((template_type, language)) + return {"system_prompt": "Analyze audio for {{ query }}"} + + monkeypatch.setattr(analyze_audio_tool, "get_prompt_template", _fake_get_prompt) + mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result") + tool = AnalyzeAudioTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl([b"ID3audio-bytes"], "what happened?") + + assert result == ["audio result"] + assert calls == [("analyze_audio", "en")] + mock_vlm_model.analyze_audio.assert_called_once() + call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs + assert hasattr(call_kwargs["audio_input"], "read") + assert call_kwargs["content_type"].startswith("audio/") + observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...") + + +def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client): + vlm_model = SimpleNamespace( + model_id="Qwen/Qwen3-VL-32B-Instruct", + client_kwargs={"base_url": "https://api.siliconflow.cn/v1"}, + ) + tool = AnalyzeAudioTool( + observer=observer_en, + vlm_model=vlm_model, + storage_client=mock_storage_client, + ) + + with pytest.raises(ValueError) as exc_info: + tool._forward_impl([b"ID3audio-bytes"], "what happened?") + + assert "Please choose a Qwen3-Omni model" in str(exc_info.value) + + +def test_analyze_video_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + calls = [] + + def _fake_get_prompt(template_type, language=None, **_): + calls.append((template_type, language)) + return {"system_prompt": "Analyze video for {{ query }}"} + + monkeypatch.setattr(analyze_video_tool, "get_prompt_template", _fake_get_prompt) + mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result") + tool = AnalyzeVideoTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?") + + assert result == ["video result"] + assert calls == [("analyze_video", "en")] + mock_vlm_model.analyze_video.assert_called_once() + call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs + assert hasattr(call_kwargs["video_input"], "read") + assert call_kwargs["content_type"].startswith("video/") + observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...") + + +@pytest.mark.parametrize( + "tool_class,input_name,error_text", + [ + (AnalyzeAudioTool, "audio_urls_list", "Video understanding model is not configured"), + (AnalyzeVideoTool, "video_urls_list", "Video understanding model is not configured"), + ], +) +def test_analyze_audio_video_require_video_understanding_model( + tool_class, input_name, error_text, observer_en, mock_storage_client): + tool = tool_class( + observer=observer_en, + vlm_model=None, + storage_client=mock_storage_client, + ) + + with pytest.raises(Exception) as exc_info: + tool._forward_impl(**{input_name: [b"media"], "query": "question"}) + + assert error_text in str(exc_info.value) diff --git a/test/sdk/core/tools/test_analyze_image_tool.py b/test/sdk/core/tools/test_analyze_image_tool.py index a8598a8ad..63be0ac54 100644 --- a/test/sdk/core/tools/test_analyze_image_tool.py +++ b/test/sdk/core/tools/test_analyze_image_tool.py @@ -136,7 +136,7 @@ def test_forward_impl_vlm_model_none(self, observer_en, mock_storage_client): with pytest.raises(Exception) as exc_info: tool._forward_impl([b"img"], "question") - assert "Vision Language Model (VLM) is not configured" in str( + assert "Image understanding model is not configured" in str( exc_info.value) def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_client): @@ -150,7 +150,7 @@ def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_cli with pytest.raises(Exception) as exc_info: tool._forward_impl([b"img"], "问题") - assert "视觉语言模型(VLM)未配置" in str(exc_info.value) + assert "图片理解模型未配置" in str(exc_info.value) def test_forward_impl_observer_none_uses_english(self, mock_vlm_model, mock_storage_client): """Test that English is used when observer is None.""" @@ -353,7 +353,7 @@ def test_observer_add_message_not_called_when_none(self, mock_vlm_model, mock_st def test_tool_name_and_description(self, tool): """Test that tool name and description are set correctly.""" assert tool.name == "analyze_image" - assert "visual language model" in tool.description.lower() + assert "image understanding model" in tool.description.lower() assert "image" in tool.description.lower() def test_tool_inputs_schema(self, tool): diff --git a/test/sdk/core/utils/test_prompt_template_utils.py b/test/sdk/core/utils/test_prompt_template_utils.py index c0a3ad634..a50929b8d 100644 --- a/test/sdk/core/utils/test_prompt_template_utils.py +++ b/test/sdk/core/utils/test_prompt_template_utils.py @@ -61,6 +61,28 @@ def test_get_prompt_template_analyze_image_en(self, mock_yaml_load, mock_file): # Verify result assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"} + @pytest.mark.parametrize( + "template_type,language,expected_file", + [ + ("analyze_audio", "en", "prompts/analyze_audio_en.yaml"), + ("analyze_audio", "zh", "prompts/analyze_audio_zh.yaml"), + ("analyze_video", "en", "prompts/analyze_video_en.yaml"), + ("analyze_video", "zh", "prompts/analyze_video_zh.yaml"), + ], + ) + @patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"\nuser_prompt: "User prompt"') + @patch('yaml.safe_load') + def test_get_prompt_template_analyze_audio_video( + self, mock_yaml_load, mock_file, template_type, language, expected_file): + """Test get_prompt_template for audio/video templates.""" + mock_yaml_load.return_value = {"system_prompt": "Test prompt", "user_prompt": "User prompt"} + + result = get_prompt_template(template_type=template_type, language=language) + + call_args = mock_file.call_args[0] + assert expected_file in call_args[0].replace('\\', '/') + assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"} + @patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"') @patch('yaml.safe_load') @patch('sdk.nexent.core.utils.prompt_template_utils.LANGUAGE', {'ZH': 'zh', 'EN': 'en'}) @@ -174,4 +196,4 @@ def test_get_prompt_template_path_resolution(self, mock_yaml_load, mock_file): assert mock_file.called call_args = mock_file.call_args[0] # Path should be absolute or contain the expected template file - assert 'analyze_image_en.yaml' in call_args[0] \ No newline at end of file + assert 'analyze_image_en.yaml' in call_args[0] diff --git a/test/sdk/multi_modal/test_load_save_object.py b/test/sdk/multi_modal/test_load_save_object.py index 92425791c..1670e6a9d 100644 --- a/test/sdk/multi_modal/test_load_save_object.py +++ b/test/sdk/multi_modal/test_load_save_object.py @@ -26,6 +26,23 @@ def test_get_client_requires_initialized_storage(): manager._get_client() +def test_s3_single_slash_url_supported(): + assert lso.is_url("s3:/bucket/path/to/image.png") == "s3" + assert lso.parse_s3_url("s3:/bucket/path/to/image.png") == ( + "bucket", + "path/to/image.png", + ) + + +def test_s3_blob_preview_url_rejected(): + assert lso.is_url("s3:/blob:http://localhost:3000/preview") is None + + +def test_parse_s3_blob_preview_url_rejected(): + with pytest.raises(ValueError, match="Invalid s3:// URL format"): + lso.parse_s3_url("s3:/blob:http://localhost:3000/preview") + + def test_download_file_from_http(monkeypatch): manager = make_manager() From 5a85a5ba94acfef67aeb43ebf1af829a3ab996a6 Mon Sep 17 00:00:00 2001 From: 827dls <1670704430@qq.com> Date: Sun, 17 May 2026 17:20:11 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E8=A1=A5=E5=85=85test=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/test_config_sync_service.py | 44 ++++++++++++++++--- .../test_config_sync_service_voice.py | 32 +++++++++++++- .../services/test_model_management_service.py | 7 ++- 3 files changed, 75 insertions(+), 8 deletions(-) diff --git a/test/backend/services/test_config_sync_service.py b/test/backend/services/test_config_sync_service.py index 0748a71b7..9928d0315 100644 --- a/test/backend/services/test_config_sync_service.py +++ b/test/backend/services/test_config_sync_service.py @@ -1,4 +1,6 @@ import sys +import types +import importlib from unittest.mock import patch, MagicMock, call import pytest @@ -22,6 +24,31 @@ minio_config_mock = MagicMock() minio_config_mock.validate = MagicMock() +if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'): + sys.modules.pop('consts.const', None) +if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'): + sys.modules.pop('consts', None) + +database_client_module = types.ModuleType('database.client') +database_client_module.MinioClient = MagicMock() +database_client_module.minio_client = minio_client_mock +database_client_module.as_dict = MagicMock(side_effect=lambda value: value) +database_client_module.db_client = MagicMock() +database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value) +database_client_module.get_db_session = MagicMock() +sys.modules['database.client'] = database_client_module +database_package = sys.modules.get('database') or importlib.import_module('database') +setattr(database_package, 'client', database_client_module) +database_model_management_module = types.ModuleType('database.model_management_db') +database_model_management_module.get_model_by_model_id = MagicMock() +database_model_management_module.get_model_id_by_display_name = MagicMock() +database_model_management_module.get_model_records = MagicMock(return_value=[]) +sys.modules['database.model_management_db'] = database_model_management_module +setattr(database_package, 'model_management_db', database_model_management_module) +backend_database_client_module = sys.modules.get('backend.database.client') +if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'): + backend_database_client_module.minio_client = minio_client_mock + patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() patch('nexent.storage.minio_config.MinIOStorageConfig', @@ -29,7 +56,7 @@ patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() patch('database.client.MinioClient', return_value=minio_client_mock).start() -patch('backend.database.client.minio_client', minio_client_mock).start() +patch('backend.database.client.minio_client', minio_client_mock, create=True).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() # Import backend modules after all patches are applied @@ -52,14 +79,17 @@ def service_mocks(): with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \ patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \ patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \ + patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \ patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \ patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \ patch('backend.services.config_sync_service.logger') as mock_logger: + mock_get_model_records.return_value = [] yield { 'tenant_config_manager': mock_tenant_config_manager, 'get_env_key': mock_get_env_key, 'safe_value': mock_safe_value, + 'get_model_records': mock_get_model_records, 'get_model_id': mock_get_model_id, 'get_model_name': mock_get_model_name, 'logger': mock_logger @@ -1336,6 +1366,8 @@ def side_effect(config_key, tenant_id=None): "MULTI_EMBEDDING_ID": {}, "RERANK_ID": {}, "VLM_ID": {}, + "VLM2_ID": {}, + "VLM3_ID": {}, "STT_ID": {}, "TTS_ID": {} } @@ -1348,7 +1380,7 @@ def side_effect(config_key, tenant_id=None): # Assert assert isinstance(result, dict) - assert len(result) == 7 # All model types should be present + assert len(result) == 9 # All model types should be present # Verify successful configs assert result["llm"]["displayName"] == "GPT-4" @@ -1372,20 +1404,20 @@ def test_build_models_config_all_failures(self, service_mocks): # Assert assert isinstance(result, dict) # All model types should still be present with empty configs - assert len(result) == 7 + assert len(result) == 9 # All configs should be empty due to exceptions - for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "stt", "tts"]: + for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "vlm2", "vlm3", "stt", "tts"]: assert result[model_key]["name"] == "" assert result[model_key]["displayName"] == "" assert result[model_key]["apiConfig"]["apiKey"] == "" assert result[model_key]["apiConfig"]["modelUrl"] == "" # Verify that logger.warning was called for each model type - assert service_mocks['logger'].warning.call_count == 7 + assert service_mocks['logger'].warning.call_count == 9 warning_calls = service_mocks['logger'].warning.call_args_list expected_configs = ["LLM_ID", "EMBEDDING_ID", "MULTI_EMBEDDING_ID", - "RERANK_ID", "VLM_ID", "STT_ID", "TTS_ID"] + "RERANK_ID", "VLM_ID", "VLM2_ID", "VLM3_ID", "STT_ID", "TTS_ID"] for i, config_key in enumerate(expected_configs): assert f"Failed to get config for {config_key}: Database completely down" in warning_calls[ i][0][0] diff --git a/test/backend/services/test_config_sync_service_voice.py b/test/backend/services/test_config_sync_service_voice.py index fcfd531f1..1a3144036 100644 --- a/test/backend/services/test_config_sync_service_voice.py +++ b/test/backend/services/test_config_sync_service_voice.py @@ -3,6 +3,8 @@ These tests cover the STT specific fields in save_config_impl. """ import sys +import types +import importlib from unittest.mock import patch, MagicMock import pytest @@ -22,6 +24,31 @@ minio_config_mock = MagicMock() minio_config_mock.validate = MagicMock() +if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'): + sys.modules.pop('consts.const', None) +if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'): + sys.modules.pop('consts', None) + +database_client_module = types.ModuleType('database.client') +database_client_module.MinioClient = MagicMock() +database_client_module.minio_client = minio_client_mock +database_client_module.as_dict = MagicMock(side_effect=lambda value: value) +database_client_module.db_client = MagicMock() +database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value) +database_client_module.get_db_session = MagicMock() +sys.modules['database.client'] = database_client_module +database_package = sys.modules.get('database') or importlib.import_module('database') +setattr(database_package, 'client', database_client_module) +database_model_management_module = types.ModuleType('database.model_management_db') +database_model_management_module.get_model_by_model_id = MagicMock() +database_model_management_module.get_model_id_by_display_name = MagicMock() +database_model_management_module.get_model_records = MagicMock(return_value=[]) +sys.modules['database.model_management_db'] = database_model_management_module +setattr(database_package, 'model_management_db', database_model_management_module) +backend_database_client_module = sys.modules.get('backend.database.client') +if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'): + backend_database_client_module.minio_client = minio_client_mock + patch('nexent.storage.storage_client_factory.create_storage_client_from_config', return_value=storage_client_mock).start() patch('nexent.storage.minio_config.MinIOStorageConfig', @@ -29,7 +56,7 @@ patch('backend.database.client.MinioClient', return_value=minio_client_mock).start() patch('database.client.MinioClient', return_value=minio_client_mock).start() -patch('backend.database.client.minio_client', minio_client_mock).start() +patch('backend.database.client.minio_client', minio_client_mock, create=True).start() patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start() # Import backend modules after all patches are applied @@ -47,14 +74,17 @@ def service_mocks(): with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \ patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \ patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \ + patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \ patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \ patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \ patch('backend.services.config_sync_service.logger') as mock_logger: + mock_get_model_records.return_value = [] yield { 'tenant_config_manager': mock_tenant_config_manager, 'get_env_key': mock_get_env_key, 'safe_value': mock_safe_value, + 'get_model_records': mock_get_model_records, 'get_model_id': mock_get_model_id, 'get_model_name': mock_get_model_name, 'logger': mock_logger diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index bd56cce40..f85cbfe05 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -315,6 +315,11 @@ def _add_repo_to_name(model_repo, model_name): def import_svc(): """Import service under MinioClient patch to avoid real initialization.""" minio_client_mock = mock.MagicMock() + sys.modules["database"] = database_mod + sys.modules["database.model_management_db"] = db_mm_mod + setattr(database_mod, "model_management_db", db_mm_mod) + sys.modules.pop("backend.services.model_management_service", None) + sys.modules.pop("services.model_management_service", None) with mock.patch("backend.database.client.MinioClient", return_value=minio_client_mock): from backend.services import model_management_service as svc # type: ignore return svc @@ -727,7 +732,7 @@ async def test_batch_create_models_max_tokens_update(): with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing), \ mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \ - mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \ + mock.patch.object(svc, "add_repo_to_name", side_effect=lambda *args, **kwargs: f"{kwargs.get('model_repo', args[0] if args else '')}/{kwargs.get('model_name', args[1] if len(args) > 1 else '')}"), \ mock.patch.object(svc, "update_model_record") as mock_update, \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): From 37b87c4e3d201bd7f641d3d9e4ef30d1559e9ca7 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Wed, 27 May 2026 10:57:41 +0800 Subject: [PATCH 3/7] Chore: Update .dockerignore to remove unnecessary backend assets --- .dockerignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.dockerignore b/.dockerignore index 45c1def32..385a6449f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -37,8 +37,6 @@ build/ *.tgz # Backend -backend/assets/* -!backend/assets/test.wav backend/flower_db.sqlite uploads/ test/ @@ -60,4 +58,4 @@ assets/ .Spotlight-V100 .Trashes ehthumbs.db -Thumbs.db \ No newline at end of file +Thumbs.db From 6986c45f5544a4144e4b1db5211a09fa9d85a137 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Wed, 27 May 2026 11:25:15 +0800 Subject: [PATCH 4/7] Feat: add optional id field to SingleModelConfig interface --- frontend/types/modelConfig.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts index d5df7459d..8f4789f6b 100644 --- a/frontend/types/modelConfig.ts +++ b/frontend/types/modelConfig.ts @@ -91,6 +91,7 @@ export interface TTSModelConfig extends SingleModelConfig { // Single model configuration interface export interface SingleModelConfig { + id?: number; modelName: string; displayName: string; apiConfig: ModelApiConfig; From c55c33c8f45b38804125d1e8d40286915027ad97 Mon Sep 17 00:00:00 2001 From: 827dls <1670704430@qq.com> Date: Wed, 27 May 2026 23:06:50 +0800 Subject: [PATCH 5/7] fix_audio_tools,Support adding models individually --- backend/services/model_health_service.py | 42 +++++++++ backend/services/model_management_service.py | 22 ++++- .../services/providers/dashscope_provider.py | 91 ++++++++++++++++++- .../services/providers/tokenpony_provider.py | 76 +++++++++++++++- 4 files changed, 221 insertions(+), 10 deletions(-) diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 58a0af91f..09d37fa0d 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -15,6 +15,11 @@ logger = logging.getLogger("model_health_service") +DASHSCOPE_MODEL_FACTORY = "dashscope" +TOKENPONY_MODEL_FACTORY = "tokenpony" +PROVIDER_CATALOG_HEALTHCHECK_FACTORIES = {DASHSCOPE_MODEL_FACTORY, TOKENPONY_MODEL_FACTORY} +PROVIDER_CATALOG_HEALTHCHECK_TYPES = {"vlm", "vlm2", "vlm3"} + def _mask_secret(value: Optional[str]) -> str: """Mask a secret value, showing only first and last 4 characters.""" @@ -64,6 +69,31 @@ async def _embedding_dimension_check( raise ValueError(f"Unsupported model type: {model_type}") +async def _provider_catalog_connectivity_check( + model_name: str, + model_type: str, + model_api_key: str, + model_factory: Optional[str], +) -> bool: + """Validate provider-managed multimodal models through their model catalog.""" + provider = (model_factory or "").lower() + if provider not in PROVIDER_CATALOG_HEALTHCHECK_FACTORIES: + return False + + from services.model_provider_service import get_provider_models + + model_list = await get_provider_models({ + "provider": provider, + "model_type": model_type, + "api_key": model_api_key, + }) + if not model_list or any(model.get("_error") for model in model_list): + return False + + expected_model_id = model_name.lower() + return any(str(model.get("id", "")).lower() == expected_model_id for model in model_list) + + async def _perform_connectivity_check( model_name: str, model_type: str, @@ -135,6 +165,18 @@ async def _perform_connectivity_check( ) connectivity = await rerank_model.connectivity_check() elif model_type in ("vlm", "vlm2", "vlm3"): + if ( + model_type in PROVIDER_CATALOG_HEALTHCHECK_TYPES + and (model_factory or "").lower() in PROVIDER_CATALOG_HEALTHCHECK_FACTORIES + ): + connectivity = await _provider_catalog_connectivity_check( + model_name=model_name, + model_type=model_type, + model_api_key=model_api_key, + model_factory=model_factory, + ) + return connectivity + observer = MessageObserver() set_monitoring_operation("connectivity_check", display_name=display_name) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 8f6d191fd..9f032728a 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -8,7 +8,6 @@ from database.model_management_db import ( create_model_record, delete_model_record, - get_model_by_display_name, get_model_by_name_factory, get_models_by_display_name, get_model_records, @@ -32,6 +31,23 @@ logger = logging.getLogger("model_management_service") +INDEPENDENT_MULTIMODAL_MODEL_TYPES = {"vlm", "vlm2", "vlm3"} + + +def _has_display_name_conflict(existing_models: List[Dict[str, Any]], model_type: Optional[str]) -> bool: + """Allow the three multimodal slots to share display names across slots.""" + if not existing_models: + return False + + if model_type in INDEPENDENT_MULTIMODAL_MODEL_TYPES: + return any( + existing.get("model_type") == model_type + or existing.get("model_type") not in INDEPENDENT_MULTIMODAL_MODEL_TYPES + for existing in existing_models + ) + + return True + async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict[str, Any]): """Create a single model record for the given tenant. @@ -77,9 +93,9 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict # Check display name conflict scoped by tenant if model_data.get("display_name"): - existing_model_by_display = get_model_by_display_name( + existing_models_by_display = get_models_by_display_name( model_data["display_name"], tenant_id) - if existing_model_by_display: + if _has_display_name_conflict(existing_models_by_display, model_data.get("model_type")): logging.error( f"Name {model_data['display_name']} is already in use, please choose another display name") raise ValueError( diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 69096fb15..497dcfe99 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -6,6 +6,75 @@ from services.providers.base import AbstractModelProvider, _classify_provider_error +DASHSCOPE_IMAGE_GENERATION_KEYWORDS = ( + "image", + "wanx", + "aitryon", + "tryon", + "flux", + "stable-diffusion", + "sdxl", +) +DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS = ( + "qwen-vl", + "qwen2-vl", + "qwen2.5-vl", + "qwen3-vl", + "qwen3.5-vl", + "qwen3.6-vl", + "-vl", + "vl-", + "vision", + "visual", + "ocr", + "qwen3.6", + "qwen-3.6", +) +DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video-understanding", "video-ocr") + + +def _modality_set(value) -> set: + if not value: + return set() + if isinstance(value, str): + return {value.lower()} + return {str(item).lower() for item in value} + + +def _has_keyword(text: str, keywords: tuple) -> bool: + return any(keyword in text for keyword in keywords) + + +def _is_dashscope_explicit_image_understanding_model(model_id: str) -> bool: + return _has_keyword(model_id, DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS) + + +def _is_dashscope_image_generation_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool: + if _is_dashscope_explicit_image_understanding_model(model_id): + return False + return "image" in res_mods or _has_keyword(model_id, DASHSCOPE_IMAGE_GENERATION_KEYWORDS) + + +def _is_dashscope_video_understanding_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool: + searchable_text = f"{model_id} {desc.lower()}" + if "video" in req_mods and "text" in res_mods: + return True + return _has_keyword(searchable_text, DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS) + + +def _is_dashscope_image_understanding_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool: + searchable_text = f"{model_id} {desc.lower()}" + if _is_dashscope_image_generation_model(model_id, desc, req_mods, res_mods): + return False + if _is_dashscope_video_understanding_model(model_id, desc, req_mods, res_mods): + return False + if ("image" in req_mods or "video" in req_mods) and "text" in res_mods: + return True + return _is_dashscope_explicit_image_understanding_model(model_id) or _has_keyword( + searchable_text, DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS + ) + + class DashScopeModelProvider(AbstractModelProvider): """Concrete implementation for DashScope (Aliyun) provider.""" @@ -57,6 +126,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models = { "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" + "vlm2": [], # Maps to image generation models + "vlm3": [], # Maps to video understanding models "embedding": [], # Maps to "embedding" / "multi_embedding" "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" @@ -71,6 +142,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: metadata = model_obj.get('inference_metadata') or {} req_mod = metadata.get('request_modality', []) res_mod = metadata.get('response_modality', []) + req_mods = _modality_set(req_mod) + res_mods = _modality_set(res_mod) model_obj.setdefault("object", model_obj.get("object", "model")) model_obj.setdefault("owned_by", model_obj.get("owned_by", "dashscope")) cleaned_model = { @@ -107,8 +180,17 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: continue # 5. VLM - vision_mods = {'Image', 'Video'} - if (set(req_mod) & vision_mods) or (set(res_mod) & vision_mods) or '视觉' in desc: + if _is_dashscope_video_understanding_model(m_id, desc, req_mods, res_mods): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm3"}) + categorized_models['vlm3'].append(cleaned_model) + continue + + if _is_dashscope_image_generation_model(m_id, desc, req_mods, res_mods): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm2"}) + categorized_models['vlm2'].append(cleaned_model) + continue + + if _is_dashscope_image_understanding_model(m_id, desc, req_mods, res_mods): cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) categorized_models['vlm'].append(cleaned_model) continue @@ -124,7 +206,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: elif target_model_type in ("embedding", "multi_embedding"): return categorized_models["embedding"] elif target_model_type in categorized_models: - return categorized_models[target_model_type] + return [ + {**model, "model_type": target_model_type} + for model in categorized_models[target_model_type] + ] else: return [] except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index ab4446c1b..be2bb9c71 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -9,6 +9,64 @@ from services.providers.base import AbstractModelProvider, _classify_provider_error +TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS = ( + "qwen-vl", + "qwen2-vl", + "qwen2.5-vl", + "qwen3-vl", + "qwen3.5-vl", + "qwen3.6-vl", + "-vl", + "vl-", + "vision", + "visual", + "ocr", + "gpt-4o", + "qwen3.6", + "qwen-3.6", +) +TOKENPONY_IMAGE_GENERATION_KEYWORDS = ( + "image", + "dall", + "flux", + "stable-diffusion", + "sdxl", + "midjourney", + "wanx", + "kolors", + "seedream", + "ideogram", + "recraft", +) +TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video") + + +def _has_keyword(text: str, keywords: tuple) -> bool: + return any(keyword in text for keyword in keywords) + + +def _is_tokenpony_explicit_image_understanding_model(model_id: str) -> bool: + return _has_keyword(model_id, TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS) + + +def _is_tokenpony_image_generation_model(model_id: str) -> bool: + if _is_tokenpony_explicit_image_understanding_model(model_id): + return False + return _has_keyword(model_id, TOKENPONY_IMAGE_GENERATION_KEYWORDS) + + +def _is_tokenpony_video_understanding_model(model_id: str) -> bool: + return _has_keyword(model_id, TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS) + + +def _is_tokenpony_image_understanding_model(model_id: str) -> bool: + if _is_tokenpony_image_generation_model(model_id): + return False + if _is_tokenpony_video_understanding_model(model_id): + return False + return _is_tokenpony_explicit_image_understanding_model(model_id) + + class TokenPonyModelProvider(AbstractModelProvider): """Concrete implementation for TokenPony provider.""" @@ -46,6 +104,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models = { "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" + "vlm2": [], # Maps to image generation models + "vlm3": [], # Maps to video understanding models "embedding": [], # Maps to "embedding" / "multi_embedding" "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" @@ -86,9 +146,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: cleaned_model.update({"model_tag": "tts", "model_type": "tts"}) categorized_models['tts'].append(cleaned_model) - # 5. VLM (Vision Language Model / Image & Video Generation) - - elif any(keyword in m_id for keyword in ['-vl', 'vl-', 'ocr', 'vision']): + # 5. Multimodal models + elif _is_tokenpony_video_understanding_model(m_id): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm3"}) + categorized_models['vlm3'].append(cleaned_model) + elif _is_tokenpony_image_generation_model(m_id): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm2"}) + categorized_models['vlm2'].append(cleaned_model) + elif _is_tokenpony_image_understanding_model(m_id): cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) categorized_models['vlm'].append(cleaned_model) @@ -104,7 +169,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: elif target_model_type in ("embedding", "multi_embedding"): return categorized_models["embedding"] elif target_model_type in categorized_models: - return categorized_models[target_model_type] + return [ + {**model, "model_type": target_model_type} + for model in categorized_models[target_model_type] + ] else: return [] From e74adb7dec1220d2e0e937116ce6e7531b178a3e Mon Sep 17 00:00:00 2001 From: 827dls <1670704430@qq.com> Date: Thu, 28 May 2026 00:29:29 +0800 Subject: [PATCH 6/7] fix test --- sdk/nexent/core/tools/analyze_audio_tool.py | 62 ++++---- sdk/nexent/core/tools/analyze_video_tool.py | 72 +++++---- .../providers/test_dashscope_provider.py | 140 +++++++++++++++++- .../providers/test_tokenpony_provider.py | 124 +++++++++++++++- .../services/test_model_health_service.py | 46 ++++++ .../services/test_model_management_service.py | 136 +++++++++++++++-- .../tools/test_analyze_audio_video_tool.py | 58 +++++++- 7 files changed, 560 insertions(+), 78 deletions(-) diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py index c7509a6c2..1e5439443 100644 --- a/sdk/nexent/core/tools/analyze_audio_tool.py +++ b/sdk/nexent/core/tools/analyze_audio_tool.py @@ -7,7 +7,7 @@ import logging from io import BytesIO -from typing import List +from typing import List, Optional from jinja2 import StrictUndefined, Template from pydantic import Field @@ -28,28 +28,29 @@ class AnalyzeAudioTool(Tool): """Tool for understanding and analyzing audio using the video understanding model.""" name = "analyze_audio" + skip_forward_signature_validation = True description = ( "This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n" - "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "It is used to understand and analyze one audio file, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " "HTTP, and HTTPS URLs.\n" "Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query." ) description_zh = ( - "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。" - "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + "使用视频理解模型,根据你的问题理解音频,并返回音频分析结果。" + "可用于理解和分析一个音频文件,支持 S3 URL(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" ) inputs = { - "audio_urls_list": { - "type": "array", - "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", - "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + "audio_url": { + "type": "string", + "description": "Audio URL (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。", }, "query": { "type": "string", "description": "User's question to guide the audio analysis", - "description_zh": "用户用于指导音频分析的问题" - } + "description_zh": "用户的问题,用于指导音频分析", + }, } init_param_descriptions = { @@ -58,9 +59,9 @@ class AnalyzeAudioTool(Tool): "storage_client": {"description": "Storage client for downloading files"}, "validate_url_access": { "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" - } + }, } - output_type = "array" + output_type = "string" category = ToolCategory.MULTIMODAL.value tool_sign = ToolSign.MULTIMODAL_OPERATION.value @@ -94,10 +95,10 @@ def __init__( validate_callback = validate_url_access self.mm = LoadSaveObjectManager( storage_client=self.storage_client, - validate_url_access=validate_callback + validate_url_access=validate_callback, ) self.forward = self.mm.load_object( - input_names=["audio_urls_list"])(self._forward_impl) + input_names=["audio_url", "audio_urls_list"])(self._forward_impl) self.running_prompt_zh = "正在分析音频..." self.running_prompt_en = "Analyzing audio..." @@ -114,10 +115,14 @@ def _validate_audio_capable_model(self) -> None: "Please choose a Qwen3-Omni model for analyze_audio." ) - def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: - """Analyze audio files and return one result per audio input.""" + def _forward_impl( + self, + audio_url: Optional[bytes] = None, + query: str = "", + audio_urls_list: Optional[List[bytes]] = None) -> str: + """Analyze an audio file and return the result as a string.""" if self.vlm_model is None: - error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试。" error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." error_msg = error_msg_zh if self._is_chinese else error_msg_en logger.error(error_msg) @@ -128,12 +133,17 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en self.observer.add_message("", ProcessType.TOOL, running_prompt) - if audio_urls_list is None: - raise ValueError("audio_urls cannot be None") - if not isinstance(audio_urls_list, list): - raise ValueError("audio_urls must be a list of bytes") - if not audio_urls_list: - raise ValueError("audio_urls must contain at least one audio file") + if audio_url is not None: + audio_items = [audio_url] + else: + audio_items = audio_urls_list + + if audio_items is None: + raise ValueError("audio_url cannot be None") + if not isinstance(audio_items, list): + raise ValueError("audio_url must be bytes or audio_urls_list must be a list of bytes") + if not audio_items: + raise ValueError("audio_url must contain an audio file") language = self.observer.lang if self.observer else "en" prompts = get_prompt_template( @@ -143,7 +153,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: try: analysis_results: List[str] = [] - for index, audio_bytes in enumerate(audio_urls_list, start=1): + for index, audio_bytes in enumerate(audio_items, start=1): logger.info(f"Analyzing audio #{index}, query: {query}") content_type = detect_content_type_from_bytes(audio_bytes) if not content_type.startswith("audio/"): @@ -153,7 +163,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: response = self.vlm_model.analyze_audio( audio_input=audio_stream, system_prompt=system_prompt, - content_type=content_type + content_type=content_type, ) except Exception as e: error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" @@ -163,7 +173,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: analysis_results.append(response.content) - return analysis_results + return "\n\n".join(analysis_results) except Exception as e: logger.error(f"Error analyzing audio: {str(e)}", exc_info=True) raise Exception(f"Error analyzing audio: {str(e)}") diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py index 3dc033551..e7bf84549 100644 --- a/sdk/nexent/core/tools/analyze_video_tool.py +++ b/sdk/nexent/core/tools/analyze_video_tool.py @@ -1,13 +1,13 @@ """ Analyze Video Tool -Analyze videos using the configured video understanding model. -Supports videos from S3, HTTP, and HTTPS URLs. +Analyze video using the configured video understanding model. +Supports video from S3, HTTP, and HTTPS URLs. """ import logging from io import BytesIO -from typing import List +from typing import List, Optional from jinja2 import StrictUndefined, Template from pydantic import Field @@ -25,31 +25,32 @@ class AnalyzeVideoTool(Tool): - """Tool for understanding and analyzing videos using the video understanding model.""" + """Tool for understanding and analyzing video using the video understanding model.""" name = "analyze_video" + skip_forward_signature_validation = True description = ( - "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n" - "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "This tool uses the configured video understanding model to understand video based on your query and then returns a video analysis result.\n" + "It is used to understand and analyze one video, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " "HTTP, and HTTPS URLs.\n" - "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query." + "Use this tool when you want to retrieve information contained in a video and provide the video URL and your query." ) description_zh = ( - "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。" - "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + "使用视频理解模型,根据你的问题理解视频,并返回视频分析结果。" + "可用于理解和分析一个视频,支持 S3 URL(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" ) inputs = { - "video_urls_list": { - "type": "array", - "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", - "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + "video_url": { + "type": "string", + "description": "Video URL (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。", }, "query": { "type": "string", "description": "User's question to guide the video analysis", - "description_zh": "用户用于指导视频分析的问题" - } + "description_zh": "用户的问题,用于指导视频分析", + }, } init_param_descriptions = { @@ -58,9 +59,9 @@ class AnalyzeVideoTool(Tool): "storage_client": {"description": "Storage client for downloading files"}, "validate_url_access": { "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" - } + }, } - output_type = "array" + output_type = "string" category = ToolCategory.MULTIMODAL.value tool_sign = ToolSign.MULTIMODAL_OPERATION.value @@ -94,18 +95,22 @@ def __init__( validate_callback = validate_url_access self.mm = LoadSaveObjectManager( storage_client=self.storage_client, - validate_url_access=validate_callback + validate_url_access=validate_callback, ) self.forward = self.mm.load_object( - input_names=["video_urls_list"])(self._forward_impl) + input_names=["video_url", "video_urls_list"])(self._forward_impl) self.running_prompt_zh = "正在分析视频..." self.running_prompt_en = "Analyzing video..." - def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: - """Analyze videos and return one result per video input.""" + def _forward_impl( + self, + video_url: Optional[bytes] = None, + query: str = "", + video_urls_list: Optional[List[bytes]] = None) -> str: + """Analyze a video and return the result as a string.""" if self.vlm_model is None: - error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试。" error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." error_msg = error_msg_zh if self._is_chinese else error_msg_en logger.error(error_msg) @@ -115,12 +120,17 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en self.observer.add_message("", ProcessType.TOOL, running_prompt) - if video_urls_list is None: - raise ValueError("video_urls cannot be None") - if not isinstance(video_urls_list, list): - raise ValueError("video_urls must be a list of bytes") - if not video_urls_list: - raise ValueError("video_urls must contain at least one video") + if video_url is not None: + video_items = [video_url] + else: + video_items = video_urls_list + + if video_items is None: + raise ValueError("video_url cannot be None") + if not isinstance(video_items, list): + raise ValueError("video_url must be bytes or video_urls_list must be a list of bytes") + if not video_items: + raise ValueError("video_url must contain a video") language = self.observer.lang if self.observer else "en" prompts = get_prompt_template( @@ -130,7 +140,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: try: analysis_results: List[str] = [] - for index, video_bytes in enumerate(video_urls_list, start=1): + for index, video_bytes in enumerate(video_items, start=1): logger.info(f"Analyzing video #{index}, query: {query}") content_type = detect_content_type_from_bytes(video_bytes) if not content_type.startswith("video/"): @@ -140,7 +150,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: response = self.vlm_model.analyze_video( video_input=video_stream, system_prompt=system_prompt, - content_type=content_type + content_type=content_type, ) except Exception as e: error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" @@ -150,7 +160,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: analysis_results.append(response.content) - return analysis_results + return "\n\n".join(analysis_results) except Exception as e: logger.error(f"Error analyzing video: {str(e)}", exc_info=True) raise Exception(f"Error analyzing video: {str(e)}") diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 30229677a..5c6267040 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -141,11 +141,27 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): "models": [ { "model": "qwen-vl-plus", - "description": "Vision language model", + "description": "Vision language model for image understanding", "inference_metadata": { "request_modality": ["Image", "Text"], "response_modality": ["Text"] } + }, + { + "model": "qwen3.6-27b", + "description": "Qwen 3.6 multimodal model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-vl-max", + "description": "Qwen VL max model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text", "Image"] + } } ] } @@ -167,11 +183,129 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) + assert [model["id"] for model in result] == ["qwen-vl-plus", "qwen3.6-27b", "qwen-vl-max"] + assert all(model["model_type"] == "vlm" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + + @pytest.mark.asyncio + async def test_get_models_vlm2_only_returns_image_generation_models(self, mocker: MockFixture): + """Image generation slot only returns image-generation multimodal models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-vl-plus", + "description": "Vision language model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-image-max", + "description": "Image generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Image"] + } + }, + { + "model": "qwen-vl-max", + "description": "Qwen VL max model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text", "Image"] + } + }, + { + "model": "qwen-plus", + "description": "Text generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + self._setup_mock_client(mocker, mock_response) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "vlm2", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + assert len(result) == 1 - assert result[0]["id"] == "qwen-vl-plus" - assert result[0]["model_type"] == "vlm" + assert result[0]["id"] == "qwen-image-max" + assert result[0]["model_type"] == "vlm2" assert result[0]["model_tag"] == "chat" + @pytest.mark.asyncio + async def test_get_models_vlm3_only_returns_video_understanding_models(self, mocker: MockFixture): + """Video understanding slot excludes image generation and text-only models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-image-max", + "description": "Image generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Image"] + } + }, + { + "model": "qwen-omni-turbo", + "description": "Video understanding model", + "inference_metadata": { + "request_modality": ["Video", "Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen3-omni-30b-a3b-instruct", + "description": "Omni multimodal model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-plus", + "description": "Text generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + self._setup_mock_client(mocker, mock_response) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "vlm3", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert [model["id"] for model in result] == ["qwen-omni-turbo", "qwen3-omni-30b-a3b-instruct"] + assert all(model["model_type"] == "vlm3" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + @pytest.mark.asyncio async def test_get_models_rerank_success(self, mocker: MockFixture): """Test successful model retrieval for rerank models.""" diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py index e93d8ba7b..58e514dbb 100644 --- a/test/backend/services/providers/test_tokenpony_provider.py +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -126,6 +126,16 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): "id": "qwen-vl-plus", "object": "model", "owned_by": "qwen" + }, + { + "id": "qwen3.6-27b", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen-vl-max", + "object": "model", + "owned_by": "qwen" } ] } @@ -155,11 +165,121 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) + assert [model["id"] for model in result] == ["qwen-vl-plus", "qwen3.6-27b", "qwen-vl-max"] + assert all(model["model_type"] == "vlm" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + + @pytest.mark.asyncio + async def test_get_models_vlm2_only_returns_image_generation_models(self, mocker: MockFixture): + """Image generation slot only returns image-generation multimodal models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "qwen-vl-plus", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "flux-image-pro", + "object": "model", + "owned_by": "flux" + }, + { + "id": "qwen-vl-max", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen-plus", + "object": "model", + "owned_by": "qwen" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm2", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + assert len(result) == 1 - assert result[0]["id"] == "qwen-vl-plus" - assert result[0]["model_type"] == "vlm" + assert result[0]["id"] == "flux-image-pro" + assert result[0]["model_type"] == "vlm2" assert result[0]["model_tag"] == "chat" + @pytest.mark.asyncio + async def test_get_models_vlm3_only_returns_video_understanding_models(self, mocker: MockFixture): + """Video understanding slot excludes image generation and text-only models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "flux-image-pro", + "object": "model", + "owned_by": "flux" + }, + { + "id": "qwen-omni-video", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen3-omni-30b-a3b-instruct", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen-plus", + "object": "model", + "owned_by": "qwen" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm3", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert [model["id"] for model in result] == ["qwen-omni-video", "qwen3-omni-30b-a3b-instruct"] + assert all(model["model_type"] == "vlm3" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + @pytest.mark.asyncio async def test_get_models_rerank_success(self, mocker: MockFixture): """Test successful model retrieval for rerank models.""" diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py index 139ded1a2..559677527 100644 --- a/test/backend/services/test_model_health_service.py +++ b/test/backend/services/test_model_health_service.py @@ -24,6 +24,7 @@ def __getattr__(cls, key): sys.modules['utils'] = MockModule() sys.modules['utils.auth_utils'] = MockModule() sys.modules['utils.config_utils'] = MockModule() +sys.modules['utils.memory_utils'] = MockModule() sys.modules['utils.model_name_utils'] = MockModule() sys.modules['consts'] = MockModule() consts_const_module = MockModule() @@ -221,6 +222,51 @@ async def test_perform_connectivity_check_vlm(): mock_model_instance.check_connectivity.assert_called_once() +@pytest.mark.asyncio +async def test_perform_connectivity_check_dashscope_multimodal_uses_provider_catalog(): + model_provider_service = types.ModuleType("services.model_provider_service") + model_provider_service.get_provider_models = mock.AsyncMock(return_value=[ + {"id": "qwen-image-max", "model_type": "vlm2"}, + ]) + + with mock.patch.dict(sys.modules, {"services.model_provider_service": model_provider_service}), \ + mock.patch("backend.services.model_health_service.OpenAIVLModel") as mock_model: + result = await _perform_connectivity_check( + "qwen-image-max", + "vlm2", + "https://dashscope.aliyuncs.com/compatible-mode/v1/", + "test-key", + model_factory="dashscope", + ) + + assert result is True + model_provider_service.get_provider_models.assert_awaited_once_with({ + "provider": "dashscope", + "model_type": "vlm2", + "api_key": "test-key", + }) + mock_model.assert_not_called() + + +@pytest.mark.asyncio +async def test_perform_connectivity_check_tokenpony_multimodal_catalog_error_returns_false(): + model_provider_service = types.ModuleType("services.model_provider_service") + model_provider_service.get_provider_models = mock.AsyncMock(return_value=[ + {"_error": "authentication_failed", "_message": "Invalid API key"}, + ]) + + with mock.patch.dict(sys.modules, {"services.model_provider_service": model_provider_service}): + result = await _perform_connectivity_check( + "qwen-vl-plus", + "vlm3", + "https://api.tokenpony.cn/v1/", + "bad-key", + model_factory="tokenpony", + ) + + assert result is False + + @pytest.mark.asyncio async def test_perform_connectivity_check_stt(): # Setup diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 83a070fe0..be9c5e406 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -83,7 +83,22 @@ def get_value(status): return status or _ModelConnectStatusEnum.NOT_DETECTED.value +class _ToolValidateRequest: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class _ProcessParams: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def model_dump(self, *args, **kwargs): + return dict(self.__dict__) + + consts_model_mod.ModelConnectStatusEnum = _ModelConnectStatusEnum +consts_model_mod.ToolValidateRequest = _ToolValidateRequest +consts_model_mod.ProcessParams = _ProcessParams sys.modules["consts.model"] = consts_model_mod if "consts" not in sys.modules: sys.modules["consts"] = types.ModuleType("consts") @@ -93,6 +108,23 @@ def get_value(status): consts_const_mod.LOCALHOST_IP = "127.0.0.1" consts_const_mod.LOCALHOST_NAME = "localhost" consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal" +consts_const_mod.DATA_PROCESS_SERVICE = "http://data-process" +consts_const_mod.FILE_PREVIEW_SIZE_LIMIT = 100 * 1024 * 1024 +consts_const_mod.MAX_CONCURRENT_UPLOADS = 5 +consts_const_mod.OFFICE_MIME_TYPES = [] +consts_const_mod.UPLOAD_FOLDER = "uploads" +consts_const_mod.LOCAL_MCP_SERVER = "http://local-mcp" +consts_const_mod.MCP_MANAGEMENT_API = "http://mcp-management" +consts_const_mod.LIBREOFFICE_PROFILE_DIR = "libreoffice-profile" +consts_const_mod.DEFAULT_TENANT_ID = "tenant_id" +consts_const_mod.DEFAULT_USER_ID = "user_id" +consts_const_mod.IS_SPEED_MODE = False +consts_const_mod.SUPABASE_JWT_SECRET = "test-secret" +consts_const_mod.SUPABASE_URL = "http://supabase" +consts_const_mod.SUPABASE_KEY = "supabase-key" +consts_const_mod.SERVICE_ROLE_KEY = "service-role-key" +consts_const_mod.DEBUG_JWT_EXPIRE_SECONDS = 3600 +consts_const_mod.LANGUAGE = "zh" # Fields required by utils.memory_utils and services.vectordatabase_service consts_const_mod.MODEL_CONFIG_MAPPING = { "llm": "LLM_ID", "embedding": "EMBEDDING_ID"} @@ -193,9 +225,23 @@ def _sort_models_by_id(model_list): utils_name_mod.sort_models_by_id = _sort_models_by_id sys.modules["utils.model_name_utils"] = utils_name_mod +# Stub utils.file_management_utils so file_management_service can be imported +# by other tests in the same pytest process without pulling auth/database deps. +utils_file_mgmt_mod = types.ModuleType("utils.file_management_utils") + + +async def _save_upload_file(*args, **kwargs): + return None + + +utils_file_mgmt_mod.save_upload_file = _save_upload_file +sys.modules["utils.file_management_utils"] = utils_file_mgmt_mod + # Stub database.model_management_db to avoid importing heavy DB client database_mod = types.ModuleType("database") +database_mod.__path__ = [] db_mm_mod = types.ModuleType("database.model_management_db") +db_attachment_mod = types.ModuleType("database.attachment_db") def _noop(*args, **kwargs): @@ -245,6 +291,22 @@ def _get_model_by_model_id(model_id: int, tenant_id: str): db_mm_mod.update_model_record = _noop sys.modules["database"] = database_mod sys.modules["database.model_management_db"] = db_mm_mod +for _attachment_func in [ + "copy_file", + "delete_file", + "file_exists", + "get_content_type", + "get_file_range", + "get_file_size_from_minio", + "get_file_stream", + "get_file_stream_raw", + "get_file_url", + "list_files", + "upload_fileobj", +]: + setattr(db_attachment_mod, _attachment_func, _noop) +sys.modules["database.attachment_db"] = db_attachment_mod +setattr(database_mod, "attachment_db", db_attachment_mod) # Stub database.tenant_config_db required by utils.config_utils db_tenant_cfg_mod = types.ModuleType("database.tenant_config_db") @@ -286,10 +348,15 @@ def _update_config_by_tenant_config_id(*args, **kwargs): services_vdb_mod = types.ModuleType("services.vectordatabase_service") +class _ElasticSearchService: + pass + + def _get_vector_db_core(): return object() +services_vdb_mod.ElasticSearchService = _ElasticSearchService services_vdb_mod.get_vector_db_core = _get_vector_db_core sys.modules["services.vectordatabase_service"] = services_vdb_mod @@ -340,7 +407,7 @@ def import_svc(): async def test_create_model_for_tenant_success_llm(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None) as mock_get_by_display, \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]) as mock_get_by_display, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("huggingface", "llama")): @@ -367,7 +434,7 @@ async def test_create_model_for_tenant_open_router_disables_ssl(): """When base_url contains 'open/router' ssl_verify should be set to False and model_factory to 'modelengine'.""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("modelengine", "m")): @@ -394,7 +461,7 @@ async def test_create_model_for_tenant_open_router_disables_ssl(): async def test_create_model_for_tenant_conflict_raises(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": "exists"}): + with mock.patch.object(svc, "get_models_by_display_name", return_value=[{"model_id": "exists", "model_type": "llm"}]): user_id = "u1" tenant_id = "t1" model_data = { @@ -415,7 +482,7 @@ async def test_create_model_for_tenant_display_name_conflict_valueerror(): svc = import_svc() existing_model = {"model_id": 1, "display_name": "existing_name"} - with mock.patch.object(svc, "get_model_by_display_name", return_value=existing_model): + with mock.patch.object(svc, "get_models_by_display_name", return_value=[existing_model]): user_id = "u1" tenant_id = "t1" model_data = { @@ -432,11 +499,58 @@ async def test_create_model_for_tenant_display_name_conflict_valueerror(): assert "existing_name" in str(exc.value) +@pytest.mark.asyncio +async def test_create_model_for_tenant_allows_same_display_name_across_multimodal_slots(): + """Image understanding, image generation, and video understanding are separate slots.""" + svc = import_svc() + + existing_models = [ + {"model_id": 1, "display_name": "Qwen3.6-27B", "model_type": "vlm"}, + {"model_id": 2, "display_name": "Qwen3.6-27B", "model_type": "vlm3"}, + ] + + with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models), \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("Qwen", "Qwen3.6-27B")): + + model_data = { + "model_name": "Qwen/Qwen3.6-27B", + "display_name": "Qwen3.6-27B", + "base_url": "https://api.example.com/v1", + "model_type": "vlm2", + } + + await svc.create_model_for_tenant("u1", "t1", model_data) + + mock_create.assert_called_once() + + +@pytest.mark.asyncio +async def test_create_model_for_tenant_blocks_duplicate_within_same_multimodal_slot(): + svc = import_svc() + + with mock.patch.object( + svc, + "get_models_by_display_name", + return_value=[{"model_id": 1, "display_name": "Qwen3.6-27B", "model_type": "vlm"}], + ): + model_data = { + "model_name": "Qwen/Qwen3.6-27B", + "display_name": "Qwen3.6-27B", + "base_url": "https://api.example.com/v1", + "model_type": "vlm", + } + + with pytest.raises(Exception) as exc: + await svc.create_model_for_tenant("u1", "t1", model_data) + assert "already in use" in str(exc.value) + + @pytest.mark.asyncio async def test_create_model_for_tenant_multi_embedding_creates_two_records(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): @@ -458,7 +572,7 @@ async def test_create_model_for_tenant_multi_embedding_creates_two_records(): async def test_create_model_for_tenant_embedding_sets_dimension(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=1536)) as mock_dim, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-ada-002")): @@ -484,7 +598,7 @@ async def test_create_model_for_tenant_embedding_sets_default_chunk_batch(): """chunk_batch defaults to 10 when not provided for embedding models.""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-3-small")): @@ -513,7 +627,7 @@ async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch( """chunk_batch defaults to 10 when not provided for multi_embedding models (covers line 79).""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): @@ -591,7 +705,7 @@ async def test_batch_create_models_for_tenant_dashscope_provider(): mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", return_value=("qwen", "qwen-turbo")), \ mock.patch.object(svc, "add_repo_to_name", return_value="qwen/qwen-turbo"), \ - mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): @@ -617,7 +731,7 @@ async def test_batch_create_models_for_tenant_tokenpony_provider(): mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", return_value=("gpt", "gpt-4o")), \ mock.patch.object(svc, "add_repo_to_name", return_value="gpt/gpt-4o"), \ - mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 2})), \ mock.patch.object(svc, "create_model_record", return_value=True): @@ -650,7 +764,7 @@ async def test_batch_create_models_for_tenant_other_provider(): mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "gpt-4")), \ mock.patch.object(svc, "add_repo_to_name", return_value="openai/gpt-4"), \ - mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py index 94401b61d..7369ddfb2 100644 --- a/test/sdk/core/tools/test_analyze_audio_video_tool.py +++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py @@ -44,9 +44,9 @@ def _fake_get_prompt(template_type, language=None, **_): storage_client=mock_storage_client, ) - result = tool._forward_impl([b"ID3audio-bytes"], "what happened?") + result = tool._forward_impl(audio_url=b"ID3audio-bytes", query="what happened?") - assert result == ["audio result"] + assert result == "audio result" assert calls == [("analyze_audio", "en")] mock_vlm_model.analyze_audio.assert_called_once() call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs @@ -55,6 +55,30 @@ def _fake_get_prompt(template_type, language=None, **_): observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...") +def test_analyze_audio_schema_uses_single_url(): + assert "audio_url" in AnalyzeAudioTool.inputs + assert "audio_urls_list" not in AnalyzeAudioTool.inputs + assert AnalyzeAudioTool.output_type == "string" + + +def test_analyze_audio_accepts_legacy_url_list(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + monkeypatch.setattr( + analyze_audio_tool, + "get_prompt_template", + lambda template_type, language=None, **_: {"system_prompt": "Analyze audio for {{ query }}"}, + ) + mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result") + tool = AnalyzeAudioTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl(audio_urls_list=[b"ID3audio-bytes"], query="what happened?") + + assert result == "audio result" + + def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client): vlm_model = SimpleNamespace( model_id="Qwen/Qwen3-VL-32B-Instruct", @@ -67,7 +91,7 @@ def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_stor ) with pytest.raises(ValueError) as exc_info: - tool._forward_impl([b"ID3audio-bytes"], "what happened?") + tool._forward_impl(audio_url=b"ID3audio-bytes", query="what happened?") assert "Please choose a Qwen3-Omni model" in str(exc_info.value) @@ -87,9 +111,9 @@ def _fake_get_prompt(template_type, language=None, **_): storage_client=mock_storage_client, ) - result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?") + result = tool._forward_impl(video_url=b"\x00\x00\x00\x18ftypmp42video-bytes", query="what happened?") - assert result == ["video result"] + assert result == "video result" assert calls == [("analyze_video", "en")] mock_vlm_model.analyze_video.assert_called_once() call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs @@ -98,6 +122,30 @@ def _fake_get_prompt(template_type, language=None, **_): observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...") +def test_analyze_video_schema_uses_single_url(): + assert "video_url" in AnalyzeVideoTool.inputs + assert "video_urls_list" not in AnalyzeVideoTool.inputs + assert AnalyzeVideoTool.output_type == "string" + + +def test_analyze_video_accepts_legacy_url_list(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + monkeypatch.setattr( + analyze_video_tool, + "get_prompt_template", + lambda template_type, language=None, **_: {"system_prompt": "Analyze video for {{ query }}"}, + ) + mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result") + tool = AnalyzeVideoTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl(video_urls_list=[b"\x00\x00\x00\x18ftypmp42video-bytes"], query="what happened?") + + assert result == "video result" + + @pytest.mark.parametrize( "tool_class,input_name,error_text", [ From 4aa9fa98af41d1270c8dfacacef30e611906f674 Mon Sep 17 00:00:00 2001 From: 827dls <1670704430@qq.com> Date: Thu, 28 May 2026 00:56:01 +0800 Subject: [PATCH 7/7] fix test 2 --- .../test_tool_configuration_service.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index ed6879c6e..29d2f75f6 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -3611,6 +3611,7 @@ def test_get_llm_model_success(self, mock_tenant_config, mock_get_model_name, mo api_key="test_api_key", max_context_tokens=4096, ssl_verify=True, + timeout_seconds=None, ) @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) @@ -3650,6 +3651,42 @@ def test_get_llm_model_with_missing_config_values(self, mock_tenant_config, mock call_kwargs = mock_openai_model.call_args[1] assert call_kwargs["api_key"] is None assert call_kwargs["max_context_tokens"] is None + assert call_kwargs["timeout_seconds"] is None + + @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) + @patch('backend.services.file_management_service.MessageObserver') + @patch('backend.services.file_management_service.OpenAILongContextModel') + @patch('backend.services.file_management_service.get_model_name_from_config') + @patch('backend.services.file_management_service.tenant_config_manager') + def test_get_llm_model_with_timeout_seconds(self, mock_tenant_config, mock_get_model_name, mock_openai_model, mock_message_observer): + """Test get_llm_model passes configured timeout_seconds.""" + from backend.services.file_management_service import get_llm_model + + mock_config = { + "base_url": "http://api.example.com", + "api_key": "test_api_key", + "max_tokens": 4096, + "timeout_seconds": 30, + } + mock_tenant_config.get_model_config.return_value = mock_config + mock_get_model_name.return_value = "gpt-4" + mock_observer_instance = Mock() + mock_message_observer.return_value = mock_observer_instance + mock_model_instance = Mock() + mock_openai_model.return_value = mock_model_instance + + result = get_llm_model("tenant123") + + assert result == mock_model_instance + mock_openai_model.assert_called_once_with( + observer=mock_observer_instance, + model_id="gpt-4", + api_base="http://api.example.com", + api_key="test_api_key", + max_context_tokens=4096, + ssl_verify=True, + timeout_seconds=30, + ) @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) @patch('backend.services.file_management_service.MessageObserver')