From fe748daca7ef5808e6c5f63fddb0647713425cc0 Mon Sep 17 00:00:00 2001
From: 827dls <1670704430@qq.com>
Date: Sun, 17 May 2026 16:46:16 +0800
Subject: [PATCH 1/7] =?UTF-8?q?=E8=A7=86=E8=A7=89=E8=AF=AD=E8=A8=80?=
=?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=88=86=E7=B1=BB=20=E4=BB=A5=E5=8F=8A?=
=?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=9F=B3=E9=A2=91=E7=90=86=E8=A7=A3=E5=B7=A5?=
=?UTF-8?q?=E5=85=B7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/agents/create_agent_info.py | 62 +++++--
backend/consts/const.py | 2 +
backend/consts/model.py | 10 ++
.../managed_system_prompt_template_en.yaml | 2 +-
.../manager_system_prompt_template_en.yaml | 2 +-
backend/services/agent_service.py | 2 +-
backend/services/config_sync_service.py | 20 ++-
backend/services/image_service.py | 30 +++-
backend/services/model_health_service.py | 2 +-
backend/services/model_management_service.py | 18 +-
.../services/providers/silicon_provider.py | 76 +++++++-
.../services/tool_configuration_service.py | 20 ++-
docker/.env.bak | 168 -----------------
.../components/agentConfig/ToolManagement.tsx | 43 ++++-
.../chat/components/chatAttachment.tsx | 10 +-
.../[locale]/chat/components/chatInput.tsx | 22 ++-
.../[locale]/chat/internal/chatInterface.tsx | 16 +-
.../components/model/ModelAddDialog.tsx | 70 +++++++-
.../components/model/ModelDeleteDialog.tsx | 31 +++-
.../components/model/ModelEditDialog.tsx | 12 +-
.../models/components/modelConfig.tsx | 46 ++++-
frontend/const/chatConfig.ts | 14 +-
frontend/const/modelConfig.ts | 2 +
frontend/hooks/model/useDashscopeModelList.ts | 4 +-
frontend/hooks/model/useTokenponyModelList.ts | 4 +-
frontend/hooks/useConfig.ts | 44 ++++-
frontend/lib/chat/chatAttachmentUtils.ts | 13 ++
frontend/public/locales/en/common.json | 8 +-
frontend/public/locales/zh/common.json | 8 +-
...{tailwind.config.ts => tailwind.config.js} | 0
frontend/types/modelConfig.ts | 4 +
sdk/nexent/core/agents/nexent_agent.py | 4 +-
sdk/nexent/core/models/openai_vlm.py | 61 +++++++
sdk/nexent/core/prompts/analyze_audio_en.yaml | 13 ++
sdk/nexent/core/prompts/analyze_audio_zh.yaml | 13 ++
sdk/nexent/core/prompts/analyze_video_en.yaml | 13 ++
sdk/nexent/core/prompts/analyze_video_zh.yaml | 13 ++
sdk/nexent/core/tools/__init__.py | 4 +
sdk/nexent/core/tools/analyze_audio_tool.py | 169 ++++++++++++++++++
sdk/nexent/core/tools/analyze_image_tool.py | 20 +--
sdk/nexent/core/tools/analyze_video_tool.py | 156 ++++++++++++++++
.../core/utils/prompt_template_utils.py | 12 +-
sdk/nexent/multi_modal/utils.py | 20 ++-
test/backend/agents/test_create_agent_info.py | 134 +++++++++++++-
.../providers/test_silicon_provider.py | 72 +++++++-
test/backend/services/test_image_service.py | 82 +++++++--
.../services/test_model_management_service.py | 50 +++---
.../test_tool_configuration_service.py | 108 ++++++++++-
test/common/test_mocks.py | 2 +
test/sdk/core/agents/test_nexent_agent.py | 46 +++++
test/sdk/core/models/test_openai_vlm.py | 60 ++++++-
.../tools/test_analyze_audio_video_tool.py | 119 ++++++++++++
.../sdk/core/tools/test_analyze_image_tool.py | 6 +-
.../core/utils/test_prompt_template_utils.py | 24 ++-
test/sdk/multi_modal/test_load_save_object.py | 17 ++
55 files changed, 1674 insertions(+), 309 deletions(-)
delete mode 100644 docker/.env.bak
rename frontend/{tailwind.config.ts => tailwind.config.js} (100%)
create mode 100644 sdk/nexent/core/prompts/analyze_audio_en.yaml
create mode 100644 sdk/nexent/core/prompts/analyze_audio_zh.yaml
create mode 100644 sdk/nexent/core/prompts/analyze_video_en.yaml
create mode 100644 sdk/nexent/core/prompts/analyze_video_zh.yaml
create mode 100644 sdk/nexent/core/tools/analyze_audio_tool.py
create mode 100644 sdk/nexent/core/tools/analyze_video_tool.py
create mode 100644 test/sdk/core/tools/test_analyze_audio_video_tool.py
diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py
index 5a11b550b..9033f9ba0 100644
--- a/backend/agents/create_agent_info.py
+++ b/backend/agents/create_agent_info.py
@@ -21,7 +21,7 @@
from database.a2a_agent_db import PROTOCOL_JSONRPC
from services.memory_config_service import build_memory_context
-from services.image_service import get_vlm_model
+from services.image_service import get_video_understanding_model, get_vlm_model
from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list
from database.agent_version_db import query_current_version_no
from database.tool_db import search_tools_for_sub_agent
@@ -31,13 +31,36 @@
from utils.model_name_utils import add_repo_to_name
from utils.prompt_template_utils import get_agent_prompt_template
from utils.config_utils import tenant_config_manager, get_model_name_from_config
-from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE
+from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET
from consts.exceptions import ValidationError
logger = logging.getLogger("create_agent_info")
logger.setLevel(logging.DEBUG)
+def _build_internal_s3_url(file: dict) -> str:
+ """Build a valid S3 URL for internal tools from uploaded file metadata."""
+ if not isinstance(file, dict):
+ return ""
+
+ object_name = str(file.get("object_name") or "").strip().lstrip("/")
+ if object_name:
+ bucket = MINIO_DEFAULT_BUCKET or "nexent"
+ return f"s3://{bucket}/{object_name}"
+
+ url = str(file.get("url") or "").strip()
+ if not url or url.startswith("blob:") or url.startswith("s3:/blob:"):
+ return ""
+
+ if url.startswith("s3://"):
+ return url
+
+ if url.startswith("s3:/"):
+ return "s3://" + url.replace("s3:/", "", 1).lstrip("/")
+
+ return "s3:/" + url
+
+
def _get_skills_for_template(
agent_id: int,
tenant_id: str,
@@ -526,10 +549,17 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
}
elif tool_config.class_name == "AnalyzeImageTool":
tool_config.metadata = {
+ # get_vlm_model reads the first multimodal slot, now shown as image understanding.
"vlm_model": get_vlm_model(tenant_id=tenant_id),
"storage_client": minio_client,
"validate_url_access": lambda urls: validate_urls_access(urls, user_id)
}
+ elif tool_config.class_name in ["AnalyzeAudioTool", "AnalyzeVideoTool"]:
+ tool_config.metadata = {
+ "vlm_model": get_video_understanding_model(tenant_id=tenant_id),
+ "storage_client": minio_client,
+ "validate_url_access": lambda urls: validate_urls_access(urls, user_id)
+ }
tool_config_list.append(tool_config)
@@ -630,10 +660,12 @@ async def join_minio_file_description_to_query(
# Collect files from current message first (higher priority)
if minio_files and isinstance(minio_files, list):
for file in minio_files:
- if isinstance(file, dict) and file.get("url") and file.get("name"):
- url = file["url"]
- if url not in seen_urls:
- seen_urls.add(url)
+ if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")):
+ s3_url = _build_internal_s3_url(file)
+ if not s3_url:
+ continue
+ if s3_url not in seen_urls:
+ seen_urls.add(s3_url)
all_files.append(file)
# Collect files from historical messages (lower priority, already-deduped)
@@ -641,10 +673,12 @@ async def join_minio_file_description_to_query(
for msg in history:
if isinstance(msg, dict) and msg.get("minio_files"):
for file in msg["minio_files"]:
- if isinstance(file, dict) and file.get("url") and file.get("name"):
- url = file["url"]
- if url not in seen_urls:
- seen_urls.add(url)
+ if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")):
+ s3_url = _build_internal_s3_url(file)
+ if not s3_url:
+ continue
+ if s3_url not in seen_urls:
+ seen_urls.add(s3_url)
all_files.append(file)
# Enforce file count limit (keep most recent files by truncating from the end)
@@ -660,7 +694,7 @@ async def join_minio_file_description_to_query(
fixed_overhead = len(prefix) + len(suffix)
for i, file in enumerate(all_files):
- s3_url = f"s3:/{file['url']}"
+ s3_url = _build_internal_s3_url(file)
presigned_url = file.get("presigned_url", "")
# Build description with both URLs
@@ -712,8 +746,10 @@ def _format_minio_files_for_content(minio_files: Optional[List[dict]], max_files
if i >= max_files:
file_lines.append(f" - ... (and {len(minio_files) - max_files} more files)")
break
- if isinstance(file, dict) and file.get("url") and file.get("name"):
- s3_url = f"s3:/{file['url']}"
+ if isinstance(file, dict) and file.get("name") and (file.get("url") or file.get("object_name")):
+ s3_url = _build_internal_s3_url(file)
+ if not s3_url:
+ continue
presigned_url = file.get("presigned_url", "")
if presigned_url:
file_lines.append(
diff --git a/backend/consts/const.py b/backend/consts/const.py
index 77e86a185..6c8a43266 100644
--- a/backend/consts/const.py
+++ b/backend/consts/const.py
@@ -305,6 +305,8 @@ class VectorDatabaseType(str, Enum):
"multiEmbedding": "MULTI_EMBEDDING_ID",
"rerank": "RERANK_ID",
"vlm": "VLM_ID",
+ "vlm2": "VLM2_ID",
+ "vlm3": "VLM3_ID",
"stt": "STT_ID",
"tts": "TTS_ID"
}
diff --git a/backend/consts/model.py b/backend/consts/model.py
index bcaffcae7..31f611f8c 100644
--- a/backend/consts/model.py
+++ b/backend/consts/model.py
@@ -160,12 +160,22 @@ class STTModelConfig(BaseModel):
accessToken: Optional[str] = None
+def _empty_model_config() -> SingleModelConfig:
+ return SingleModelConfig(
+ modelName="",
+ displayName="",
+ apiConfig=ModelApiConfig(apiKey="", modelUrl="")
+ )
+
+
class ModelConfig(BaseModel):
llm: SingleModelConfig
embedding: SingleModelConfig
multiEmbedding: SingleModelConfig
rerank: SingleModelConfig
vlm: SingleModelConfig
+ vlm2: SingleModelConfig = Field(default_factory=_empty_model_config)
+ vlm3: SingleModelConfig = Field(default_factory=_empty_model_config)
stt: STTModelConfig
diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml
index 67da8305c..94b35f66d 100644
--- a/backend/prompts/managed_system_prompt_template_en.yaml
+++ b/backend/prompts/managed_system_prompt_template_en.yaml
@@ -116,7 +116,7 @@ system_prompt: |-
→ Use **presigned_url** (already includes proxy prefix, format: `http://.../api/nb/v1/file/fetch?presigned_url=...`)
Directly use the **presigned_url** field provided in the user's uploaded file info. No need to construct or append anything.
2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image):
- → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`)
+ → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`)
Reason: Internal tools run inside Nexent and can directly access MinIO storage
{%- else %}
diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml
index a4ffae074..feccecc9c 100644
--- a/backend/prompts/manager_system_prompt_template_en.yaml
+++ b/backend/prompts/manager_system_prompt_template_en.yaml
@@ -118,7 +118,7 @@ system_prompt: |-
→ Use **Download URL** (format: `https://minio.example.com/...?token=xxx`)
Reason: MCP tools run on external services and cannot access internal S3 storage
2. **Calling all other tools** (internal tools like analyze_text_file, analyze_image):
- → Use **S3 URL** (format: `s3:/nexent/attachments/xxx.pdf`)
+ → Use **S3 URL** (format: `s3://nexent/attachments/xxx.pdf`)
Reason: Internal tools run inside Nexent and can directly access MinIO storage
{%- else %}
- No tools are currently available
diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py
index 02fa7d8c6..080c438a8 100644
--- a/backend/services/agent_service.py
+++ b/backend/services/agent_service.py
@@ -1152,7 +1152,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str)
# Check if any tool is KnowledgeBaseSearchTool and set its metadata to empty dict
for tool in tool_list:
- if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "DataMateSearchTool"]:
+ if tool.class_name in ["KnowledgeBaseSearchTool", "AnalyzeTextFileTool", "AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool", "DataMateSearchTool"]:
tool.metadata = {}
# Get model_id and model display name from agent_info
diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py
index 0ed29bfc5..2a585fd50 100644
--- a/backend/services/config_sync_service.py
+++ b/backend/services/config_sync_service.py
@@ -20,7 +20,7 @@
MODEL_ENGINE_ENABLED,
TENANT_NAME
)
-from database.model_management_db import get_model_id_by_display_name
+from database.model_management_db import get_model_id_by_display_name, get_model_records
from utils.config_utils import (
get_env_key,
get_model_name_from_config,
@@ -31,6 +31,20 @@
logger = logging.getLogger("config_sync_service")
+def get_model_id_for_config(model_type: str, display_name: str, tenant_id: str) -> Optional[int]:
+ if not display_name:
+ return None
+
+ records = get_model_records(
+ {"display_name": display_name, "model_type": model_type},
+ tenant_id
+ )
+ if records:
+ return records[0].get("model_id")
+
+ return get_model_id_by_display_name(display_name, tenant_id)
+
+
def handle_model_config(tenant_id: str, user_id: str, config_key: str, model_id: Optional[int], tenant_config_dict: dict) -> None:
"""
Handle model configuration updates, deletions, and settings operations
@@ -98,8 +112,8 @@ async def save_config_impl(config, tenant_id, user_id):
model_display_name = model_config.get("displayName")
config_key = get_env_key(model_type) + "_ID"
- model_id = get_model_id_by_display_name(
- model_display_name, tenant_id)
+ model_id = get_model_id_for_config(
+ model_type, model_display_name, tenant_id)
handle_model_config(tenant_id, user_id, config_key,
model_id, tenant_config_dict)
diff --git a/backend/services/image_service.py b/backend/services/image_service.py
index 8decbd541..8a924e9cc 100644
--- a/backend/services/image_service.py
+++ b/backend/services/image_service.py
@@ -31,7 +31,11 @@ async def proxy_image_impl(decoded_url: str):
def get_vlm_model(tenant_id: str):
- # Get the tenant config
+ """Return the configured image understanding model for AnalyzeImageTool.
+
+ The first multimodal model slot is still stored under MODEL_CONFIG_MAPPING["vlm"]
+ for compatibility, but it is the user-facing image understanding configuration.
+ """
vlm_model_config = tenant_config_manager.get_model_config(
key=MODEL_CONFIG_MAPPING["vlm"], tenant_id=tenant_id)
if not vlm_model_config:
@@ -48,3 +52,27 @@ def get_vlm_model(tenant_id: str):
max_tokens=512,
ssl_verify=vlm_model_config.get("ssl_verify", True),
)
+
+
+def get_image_understanding_model(tenant_id: str):
+ return get_vlm_model(tenant_id=tenant_id)
+
+
+def get_video_understanding_model(tenant_id: str):
+ """Return the configured video understanding model for multimodal tools."""
+ vlm_model_config = tenant_config_manager.get_model_config(
+ key=MODEL_CONFIG_MAPPING["vlm3"], tenant_id=tenant_id)
+ if not vlm_model_config:
+ return None
+ return OpenAIVLModel(
+ observer=MessageObserver(),
+ model_id=get_model_name_from_config(
+ vlm_model_config) if vlm_model_config else "",
+ api_base=vlm_model_config.get("base_url", ""),
+ api_key=vlm_model_config.get("api_key", ""),
+ temperature=0.7,
+ top_p=0.7,
+ frequency_penalty=0.5,
+ max_tokens=512,
+ ssl_verify=vlm_model_config.get("ssl_verify", True),
+ )
diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py
index a20b2a6ca..5044f50f0 100644
--- a/backend/services/model_health_service.py
+++ b/backend/services/model_health_service.py
@@ -125,7 +125,7 @@ async def _perform_connectivity_check(
ssl_verify=ssl_verify,
)
connectivity = await rerank_model.connectivity_check()
- elif model_type == "vlm":
+ elif model_type in ("vlm", "vlm2", "vlm3"):
observer = MessageObserver()
set_monitoring_operation("connectivity_check",
display_name=display_name)
diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py
index d012803be..268c00413 100644
--- a/backend/services/model_management_service.py
+++ b/backend/services/model_management_service.py
@@ -153,6 +153,13 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay
tenant_id, provider, model_type)
model_list_ids = {model.get("id")
for model in model_list} if model_list else set()
+ existing_model_map = {
+ add_repo_to_name(
+ model_repo=model["model_repo"],
+ model_name=model["model_name"],
+ ): model
+ for model in existing_model_list
+ }
# Delete existing models not present
for model in existing_model_list:
@@ -162,21 +169,20 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay
# Create or update new models
for model in model_list:
+ model["model_type"] = model_type
_, model_name = split_repo_name(
model["id"]) if model.get("id") else ("", "")
model_repo, model_name_only = split_repo_name(
model.get("id", "")) if model.get("id") else ("", "")
model_display_name = add_repo_to_name(model_repo, model_name_only)
if model_name:
- existing_model_by_display = get_model_by_display_name(
- model_display_name, tenant_id)
- if existing_model_by_display:
+ existing_model = existing_model_map.get(model_display_name)
+ if existing_model:
# Check if max_tokens has changed
- existing_max_tokens = existing_model_by_display.get(
- "max_tokens")
+ existing_max_tokens = existing_model.get("max_tokens")
new_max_tokens = model.get("max_tokens")
if new_max_tokens is not None and existing_max_tokens != new_max_tokens:
- update_model_record(existing_model_by_display["model_id"], {
+ update_model_record(existing_model["model_id"], {
"max_tokens": new_max_tokens}, user_id)
continue
diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py
index ea41cc95d..130f2346e 100644
--- a/backend/services/providers/silicon_provider.py
+++ b/backend/services/providers/silicon_provider.py
@@ -1,4 +1,5 @@
import httpx
+import re
from typing import Dict, List
from consts.const import DEFAULT_LLM_MAX_TOKENS
@@ -6,6 +7,62 @@
from services.providers.base import AbstractModelProvider, _classify_provider_error
+SILICON_VLM_MODEL_KEYWORDS = (
+ "-vl",
+ "_vl",
+ "/vl",
+ ".vl",
+ "vl-",
+ "vision",
+ "visual",
+ "internvl",
+ "deepseek-vl",
+ "deepseekvl",
+ "glm-4v",
+ "minicpm-v",
+ "llava",
+ "kimi-vl",
+ "kimi-k2.5",
+ "kimi-k2.6",
+ "qvq",
+ "omni",
+ "qwen3.5",
+ "qwen3.6",
+)
+
+SILICON_VLM_METADATA_KEYWORDS = ("image", "video", "vision", "visual")
+
+
+def _contains_silicon_vlm_metadata(value) -> bool:
+ if isinstance(value, str):
+ lower_value = value.lower()
+ return any(keyword in lower_value for keyword in SILICON_VLM_METADATA_KEYWORDS)
+ if isinstance(value, list):
+ return any(_contains_silicon_vlm_metadata(item) for item in value)
+ if isinstance(value, dict):
+ return any(_contains_silicon_vlm_metadata(item) for item in value.values())
+ return False
+
+
+def _is_silicon_vlm_model(model: Dict) -> bool:
+ if _contains_silicon_vlm_metadata(model):
+ return True
+
+ model_id = str(model.get("id", "")).lower()
+ model_name = str(model.get("name", "")).lower()
+ searchable_text = f"{model_id} {model_name}"
+ if any(keyword in searchable_text for keyword in SILICON_VLM_MODEL_KEYWORDS):
+ return True
+
+ return bool(re.search(r"glm-\d+(?:\.\d+)?v", searchable_text))
+
+
+def _is_silicon_omni_model(model: Dict) -> bool:
+ model_id = str(model.get("id", "")).lower()
+ model_name = str(model.get("name", "")).lower()
+ return "omni" in f"{model_id} {model_name}"
+
+
class SiliconModelProvider(AbstractModelProvider):
"""Concrete implementation for SiliconFlow provider."""
@@ -25,12 +82,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
headers = {"Authorization": f"Bearer {model_api_key}"}
+ provider_model_type = "vlm" if model_type in ("vlm2", "vlm3") else model_type
+
# Choose endpoint by model type
- if model_type in ("llm", "vlm"):
+ if provider_model_type in ("llm", "vlm"):
silicon_url = f"{SILICON_GET_URL}?sub_type=chat"
- elif model_type in ("embedding", "multi_embedding"):
+ elif provider_model_type in ("embedding", "multi_embedding"):
silicon_url = f"{SILICON_GET_URL}?sub_type=embedding"
- elif model_type == "rerank":
+ elif provider_model_type == "rerank":
silicon_url = f"{SILICON_GET_URL}?sub_type=reranker"
else:
silicon_url = SILICON_GET_URL
@@ -40,17 +99,22 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
response.raise_for_status()
model_list: List[Dict] = response.json()["data"]
+ if model_type == "vlm3":
+ model_list = [item for item in model_list if _is_silicon_omni_model(item)]
+ elif provider_model_type == "vlm":
+ model_list = [item for item in model_list if _is_silicon_vlm_model(item)]
+
# Annotate models with canonical fields expected downstream
- if model_type in ("llm", "vlm"):
+ if provider_model_type in ("llm", "vlm"):
for item in model_list:
item["model_tag"] = "chat"
item["model_type"] = model_type
item["max_tokens"] = DEFAULT_LLM_MAX_TOKENS
- elif model_type in ("embedding", "multi_embedding"):
+ elif provider_model_type in ("embedding", "multi_embedding"):
for item in model_list:
item["model_tag"] = "embedding"
item["model_type"] = model_type
- elif model_type == "rerank":
+ elif provider_model_type == "rerank":
for item in model_list:
item["model_tag"] = "rerank"
item["model_type"] = model_type
diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py
index 5e5229ff6..d3f3d513b 100644
--- a/backend/services/tool_configuration_service.py
+++ b/backend/services/tool_configuration_service.py
@@ -38,7 +38,7 @@
from services.file_management_service import get_llm_model, validate_urls_access
from services.vectordatabase_service import get_embedding_model_by_index_name, get_rerank_model
from database.client import minio_client
-from services.image_service import get_vlm_model
+from services.image_service import get_video_understanding_model, get_vlm_model
from nexent.monitor import set_monitoring_context, set_monitoring_operation
from services.vectordatabase_service import get_vector_db_core
from utils.langchain_utils import discover_langchain_modules
@@ -765,6 +765,7 @@ def _validate_local_tool(
if not tenant_id or not user_id:
raise ToolExecutionException(
f"Tenant ID and User ID are required for {tool_name} validation")
+ # get_vlm_model reads the first multimodal slot, now shown as image understanding.
image_to_text_model = get_vlm_model(tenant_id=tenant_id)
vlm_display_name = getattr(
image_to_text_model, 'display_name', None)
@@ -778,6 +779,23 @@ def _validate_local_tool(
'validate_url_access': lambda urls: validate_urls_access(urls, user_id)
}
tool_instance = tool_class(**params)
+ elif tool_name in ["analyze_audio", "analyze_video"]:
+ if not tenant_id or not user_id:
+ raise ToolExecutionException(
+ f"Tenant ID and User ID are required for {tool_name} validation")
+ video_understanding_model = get_video_understanding_model(tenant_id=tenant_id)
+ model_display_name = getattr(
+ video_understanding_model, 'display_name', None)
+ set_monitoring_context(tenant_id=tenant_id)
+ set_monitoring_operation(
+ "tool_validation", display_name=model_display_name)
+ params = {
+ **instantiation_params,
+ 'vlm_model': video_understanding_model,
+ 'storage_client': minio_client,
+ 'validate_url_access': lambda urls: validate_urls_access(urls, user_id)
+ }
+ tool_instance = tool_class(**params)
elif tool_name == "analyze_text_file":
if not tenant_id or not user_id:
raise ToolExecutionException(
diff --git a/docker/.env.bak b/docker/.env.bak
deleted file mode 100644
index 24b53751b..000000000
--- a/docker/.env.bak
+++ /dev/null
@@ -1,168 +0,0 @@
-# ===== Necessary Configs (Necessary till now, will be migrated to frontend page) =====
-
-# Voice Service Config
-APPID=app_id
-TOKEN=token
-
-# ===== Non-essential Configs (Modify if you know what you are doing) =====
-
-CLUSTER=volcano_tts
-VOICE_TYPE=zh_male_jieshuonansheng_mars_bigtts
-SPEED_RATIO=1.3
-
-# ===== Proxy Configuration (Optional) =====
-
-# HTTP_PROXY=http://proxy-server:port
-# HTTPS_PROXY=http://proxy-server:port
-# NO_PROXY=localhost,127.0.0.1
-
-# ===== Backend Configuration (No need to modify at all) =====
-
-# Model Path Config
-CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32
-NLTK_DATA=/opt/models/nltk_data
-
-# Elasticsearch Service
-ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200
-ELASTIC_PASSWORD=nexent@2025
-
-# Elasticsearch Memory Configuration
-ES_JAVA_OPTS="-Xms2g -Xmx2g"
-
-# Elasticsearch Disk Watermark Configuration
-ES_DISK_WATERMARK_LOW=85%
-ES_DISK_WATERMARK_HIGH=90%
-ES_DISK_WATERMARK_FLOOD_STAGE=95%
-
-# Main Services
-# Config service (port 5010) - Main API service for config operations
-CONFIG_SERVICE_URL=http://nexent-config:5010
-ELASTICSEARCH_SERVICE=http://nexent-config:5010/api
-
-# Runtime service (port 5014) - Runtime execution service for agent operations
-RUNTIME_SERVICE_URL=http://nexent-runtime:5014
-
-# MCP service (port 5011) - MCP protocol service
-NEXENT_MCP_SERVER=http://nexent-mcp:5011
-MCP_MANAGEMENT_API=http://nexent-mcp:5015
-
-# Data process service (port 5012) - Data processing service
-DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api
-
-# Northbound service (port 5013) - Northbound API service
-NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api
-
-# Postgres Config
-POSTGRES_HOST=nexent-postgresql
-POSTGRES_USER=root
-NEXENT_POSTGRES_PASSWORD=nexent@4321
-POSTGRES_DB=nexent
-POSTGRES_PORT=5432
-
-# Minio Config
-MINIO_ENDPOINT=http://nexent-minio:9000
-MINIO_ROOT_USER=nexent
-MINIO_ROOT_PASSWORD=nexent@4321
-MINIO_REGION=cn-north-1
-MINIO_DEFAULT_BUCKET=nexent
-
-# Redis Config
-REDIS_URL=redis://redis:6379/0
-REDIS_BACKEND_URL=redis://redis:6379/1
-
-# Model Engine Config
-MODEL_ENGINE_ENABLED=false
-
-# Supabase Config
-DASHBOARD_USERNAME=supabase
-DASHBOARD_PASSWORD=Huawei123
-
-# Supabase db Config
-SUPABASE_POSTGRES_PASSWORD=Huawei123
-SUPABASE_POSTGRES_HOST=db
-SUPABASE_POSTGRES_DB=supabase
-SUPABASE_POSTGRES_PORT=5436
-
-# Supabase Auth Config
-SITE_URL=http://localhost:3011
-SUPABASE_URL=http://supabase-kong-mini:8000
-API_EXTERNAL_URL=http://supabase-kong-mini:8000
-DISABLE_SIGNUP=false
-JWT_EXPIRY=3600
-DEBUG_JWT_EXPIRE_SECONDS=0
-
-# Supabase Configuration
-ENABLE_EMAIL_SIGNUP=true
-ENABLE_EMAIL_AUTOCONFIRM=true
-ENABLE_ANONYMOUS_USERS=false
-
-# Supabase Phone Config
-ENABLE_PHONE_SIGNUP=false
-ENABLE_PHONE_AUTOCONFIRM=false
-
-MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
-MAILER_URLPATHS_INVITE="/auth/v1/verify"
-MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
-MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
-
-INVITE_CODE=nexent2025
-
-# Terminal Tool SSH Key Path
-SSH_PRIVATE_KEY_PATH=/path/to/openssh-server/ssh-keys/openssh_server_key
-
-# ===== Data Processing Service Configuration =====
-
-# Redis Port
-REDIS_PORT=6379
-
-# Flower Monitoring
-FLOWER_PORT=5555
-
-# Ray Configuration
-RAY_ACTOR_NUM_CPUS=2
-RAY_DASHBOARD_PORT=8265
-RAY_DASHBOARD_HOST=0.0.0.0
-RAY_NUM_CPUS=4
-RAY_OBJECT_STORE_MEMORY_GB=0.25
-RAY_TEMP_DIR=/tmp/ray
-RAY_LOG_LEVEL=INFO
-
-# Service Control Flags
-DISABLE_RAY_DASHBOARD=true
-DISABLE_CELERY_FLOWER=true
-DOCKER_ENVIRONMENT=false
-ENABLE_UPLOAD_IMAGE=false
-
-# Celery Configuration
-CELERY_WORKER_PREFETCH_MULTIPLIER=1
-CELERY_TASK_TIME_LIMIT=3600
-ELASTICSEARCH_REQUEST_TIMEOUT=30
-
-# Worker Configuration
-QUEUES=process_q,forward_q
-WORKER_NAME=
-WORKER_CONCURRENCY=4
-
-# Skills Configuration
-SKILLS_PATH=/mnt/nexent/skills
-
-# Telemetry and Monitoring Configuration
-ENABLE_TELEMETRY=false
-SERVICE_NAME=nexent-backend
-JAEGER_ENDPOINT=http://localhost:14268/api/traces
-PROMETHEUS_PORT=8000
-TELEMETRY_SAMPLE_RATE=1.0
-LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0
-LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0
-
-# Market Backend Address
-MARKET_BACKEND=http://60.204.251.153:8010
-DEPLOYMENT_VERSION="speed"
-# Root dir
-ROOT_DIR="/c/Users/18270/nexent-data"
-TERMINAL_MOUNT_DIR="/opt/terminal"
-SSH_USERNAME="root"
-SSH_PASSWORD="731215"
-NEXENT_MCP_DOCKER_IMAGE="ccr.ccs.tencentyun.com/nexent-hub/nexent-mcp:v2.0.1"
-MINIO_ACCESS_KEY="72c31cb5b521511cea652723"
-MINIO_SECRET_KEY="m5gcSuKzZnp84CqmG7z5VKnd2C+H5U3PSr7eoJeygmI="
diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
index 909592345..7dbba48d7 100644
--- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
+++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
@@ -34,11 +34,17 @@ const TOOLS_REQUIRING_EMBEDDING = [
"knowledge_base_search",
];
-// Tool types that require VLM model
-const TOOLS_REQUIRING_VLM = [
+// Tool types that require the image understanding model
+const TOOLS_REQUIRING_IMAGE_UNDERSTANDING = [
"analyze_image",
];
+// Tool types that require the video understanding model
+const TOOLS_REQUIRING_VIDEO_UNDERSTANDING = [
+ "analyze_audio",
+ "analyze_video",
+];
+
function getToolKbType(
toolName: string
): "knowledge_base_search" | "dify_search" | "datamate_search" | "idata_search" | "haotian_search" | null {
@@ -53,9 +59,18 @@ function getToolKbType(
/**
* Check if a tool requires VLM model but VLM is not available
*/
-function isToolDisabledDueToVlm(toolName: string, vlmAvailable: boolean): boolean {
- if (!TOOLS_REQUIRING_VLM.includes(toolName)) return false;
- return !vlmAvailable;
+function isToolDisabledDueToVlm(
+ toolName: string,
+ imageUnderstandingAvailable: boolean,
+ videoUnderstandingAvailable: boolean
+): boolean {
+ if (TOOLS_REQUIRING_IMAGE_UNDERSTANDING.includes(toolName)) {
+ return !imageUnderstandingAvailable;
+ }
+ if (TOOLS_REQUIRING_VIDEO_UNDERSTANDING.includes(toolName)) {
+ return !videoUnderstandingAvailable;
+ }
+ return false;
}
/**
@@ -102,7 +117,11 @@ export default function ToolManagement({
// Use tool list hook for data management
const { availableTools } = useToolList();
- const { isVlmAvailable, isEmbeddingAvailable } = useConfig();
+ const {
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable,
+ isEmbeddingAvailable,
+ } = useConfig();
// Prefetch knowledge bases for KB tools
const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases();
@@ -362,7 +381,11 @@ export default function ToolManagement({
const isSelected = originalSelectedToolIdsSet.has(
tool.id
);
- const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable);
+ const isDisabledDueToVlm = isToolDisabledDueToVlm(
+ tool.name,
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable
+ );
const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable);
const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly;
// Tooltip priority: permission > VLM > Embedding
@@ -467,7 +490,11 @@ export default function ToolManagement({
>
{group.tools.map((tool) => {
const isSelected = originalSelectedToolIdsSet.has(tool.id);
- const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable);
+ const isDisabledDueToVlm = isToolDisabledDueToVlm(
+ tool.name,
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable
+ );
const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable);
const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly;
// Tooltip priority: permission > VLM > Embedding
diff --git a/frontend/app/[locale]/chat/components/chatAttachment.tsx b/frontend/app/[locale]/chat/components/chatAttachment.tsx
index 5c9da8ec9..d12e939cd 100644
--- a/frontend/app/[locale]/chat/components/chatAttachment.tsx
+++ b/frontend/app/[locale]/chat/components/chatAttachment.tsx
@@ -87,6 +87,14 @@ const getFileIcon = (name: string, contentType?: string) => {
return ;
}
+ // Audio and video files are uploaded as regular attachments for multimodal tools.
+ if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) {
+ return ;
+ }
+ if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) {
+ return ;
+ }
+
// Compressed file
if (chatConfig.fileIcons.compressed.includes(extension)) {
return ;
@@ -230,4 +238,4 @@ export function ChatAttachment({
)}
);
-}
\ No newline at end of file
+}
diff --git a/frontend/app/[locale]/chat/components/chatInput.tsx b/frontend/app/[locale]/chat/components/chatInput.tsx
index 8de0d17eb..bcfc86f6b 100644
--- a/frontend/app/[locale]/chat/components/chatInput.tsx
+++ b/frontend/app/[locale]/chat/components/chatInput.tsx
@@ -96,10 +96,24 @@ const getFileIcon = (file: File) => {
return ;
}
+ if (chatConfig.fileIcons.audio.includes(extension) || fileType.startsWith("audio/")) {
+ return ;
+ }
+
+ if (chatConfig.fileIcons.video.includes(extension) || fileType.startsWith("video/")) {
+ return ;
+ }
+
// Default file icon
return ;
};
+const isSupportedMediaFile = (extension: string, fileType: string) =>
+ fileType.startsWith("audio/") ||
+ fileType.startsWith("video/") ||
+ chatConfig.audioExtensions.includes(extension) ||
+ chatConfig.videoExtensions.includes(extension);
+
// File limit constants from config
const MAX_FILE_COUNT = chatConfig.maxFileCount;
const MAX_FILE_SIZE = chatConfig.maxFileSize;
@@ -617,8 +631,9 @@ export function ChatInput({
chatConfig.supportedTextExtensions.includes(extension) ||
file.type === "text/csv" ||
file.type === "text/plain";
+ const isMedia = isSupportedMediaFile(extension, file.type);
- if (isImage || isDocument || isSupportedTextFile) {
+ if (isImage || isDocument || isSupportedTextFile || isMedia) {
// Create a preview URL for images
const previewUrl = isImage ? URL.createObjectURL(file) : undefined;
@@ -899,7 +914,7 @@ export function ChatInput({
id="file-upload-regular"
className="hidden"
onChange={handleFileUpload}
- accept={`image/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`}
+ accept={`image/*,audio/*,video/*,${Object.values(chatConfig.fileIcons).flat().map(ext => `.${ext}`).join(',')}`}
multiple
/>
@@ -1026,8 +1041,9 @@ export function ChatInput({
chatConfig.supportedTextExtensions.includes(extension) ||
fileType === "text/csv" ||
fileType === "text/plain";
+ const isMedia = isSupportedMediaFile(extension, fileType);
- return !(isImage || isDocument || isSupportedTextFile);
+ return !(isImage || isDocument || isSupportedTextFile || isMedia);
});
// Regular mode, keep the original rendering logic
diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx
index 6e0de48b5..0f3c99715 100644
--- a/frontend/app/[locale]/chat/internal/chatInterface.tsx
+++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx
@@ -38,7 +38,7 @@ import {
extractAssistantMsgFromResponse,
} from "@/lib/chatMessageExtractor";
-import { Layout } from "antd";
+import { Layout, message } from "antd";
import log from "@/lib/logger";
const stepIdCounter = { current: 0 };
@@ -268,9 +268,23 @@ export function ChatInterface() {
// Use preprocessing function to upload attachments
const uploadResult = await uploadAttachments(attachments, t);
+ if (uploadResult.error) {
+ message.error(`${t("chatPreprocess.fileUploadFailed")} ${uploadResult.error}`);
+ setIsLoading(false);
+ return;
+ }
uploadedFileUrls = uploadResult.uploadedFileUrls;
objectNames = uploadResult.objectNames; // Get object name mapping
presignedUrls = uploadResult.presignedUrls; // Get presigned URLs for external access
+
+ const missingUploads = attachments.filter(
+ (attachment) => !uploadedFileUrls[attachment.file.name] || !objectNames[attachment.file.name]
+ );
+ if (missingUploads.length > 0) {
+ message.error(`${t("chatPreprocess.fileUploadFailed")} ${missingUploads.map((item) => item.file.name).join(", ")}`);
+ setIsLoading(false);
+ return;
+ }
}
// Use preprocessing function to create message attachments
diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
index 11391c133..85737c251 100644
--- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
@@ -66,6 +66,23 @@ const DEFAULT_FORM_STATE = {
accessToken: "",
};
+const resolveConnectivityModelType = (type: ModelType): ModelType =>
+ type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3
+ ? (MODEL_TYPES.VLM as ModelType)
+ : type;
+
+const resolveConfigKey = (type: ModelType): string =>
+ type;
+
+const isVlmConfigType = (type: ModelType): boolean =>
+ type === MODEL_TYPES.VLM || type === MODEL_TYPES.VLM2 || type === MODEL_TYPES.VLM3;
+
+const emptyModelConfig = {
+ modelName: "",
+ displayName: "",
+ apiConfig: { apiKey: "", modelUrl: "" },
+};
+
// Connectivity status type comes from utils
// Helper function to translate error messages from backend
@@ -196,7 +213,7 @@ export const ModelAddDialog = ({
}: ModelAddDialogProps) => {
const { t } = useTranslation();
const { message } = App.useApp();
- const { updateModelConfig, saveConfig } = useConfig();
+ const { modelConfig: currentModelConfig, updateModelConfig, saveConfig } = useConfig();
// Parse backend error message and return i18n key with params
const parseModelError = (
@@ -475,7 +492,7 @@ export const ModelAddDialog = ({
const modelType =
form.type === MODEL_TYPES.EMBEDDING && form.isMultimodal
? (MODEL_TYPES.MULTI_EMBEDDING as ModelType)
- : form.type;
+ : resolveConnectivityModelType(form.type);
let connectivity = false;
@@ -615,6 +632,32 @@ export const ModelAddDialog = ({
});
}
+ if (isVlmConfigType(form.type) && enabledModels.length > 0) {
+ const selectedModel = enabledModels[0];
+ const selectedDisplayName = selectedModel.displayName || selectedModel.id || "";
+ const configKey = resolveConfigKey(form.type);
+ const vlmConfigUpdate: any = {
+ [configKey]: {
+ modelName: selectedModel.id || selectedModel.model_name || "",
+ displayName: selectedDisplayName,
+ apiConfig: {
+ apiKey: form.apiKey,
+ modelUrl: "",
+ },
+ },
+ };
+ for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) {
+ if (
+ key !== configKey &&
+ currentModelConfig?.[key]?.displayName === selectedDisplayName
+ ) {
+ vlmConfigUpdate[key] = emptyModelConfig;
+ }
+ }
+ updateModelConfig(vlmConfigUpdate);
+ await persistModelConfig();
+ }
+
// Reset form state and close dialog on success
resetForm();
handleClose();
@@ -772,6 +815,7 @@ export const ModelAddDialog = ({
// Update the local storage according to the model type
let configUpdate: any = {};
+ const configKey = resolveConfigKey(form.type);
switch (modelType) {
case MODEL_TYPES.LLM:
@@ -784,7 +828,17 @@ export const ModelAddDialog = ({
configUpdate = { multiEmbedding: modelConfig };
break;
case MODEL_TYPES.VLM:
- configUpdate = { vlm: modelConfig };
+ case MODEL_TYPES.VLM2:
+ case MODEL_TYPES.VLM3:
+ configUpdate = { [configKey]: modelConfig };
+ for (const key of [MODEL_TYPES.VLM, MODEL_TYPES.VLM2, MODEL_TYPES.VLM3]) {
+ if (
+ key !== configKey &&
+ currentModelConfig?.[key]?.displayName === modelConfig.displayName
+ ) {
+ configUpdate[key] = emptyModelConfig;
+ }
+ }
break;
case MODEL_TYPES.RERANK:
configUpdate = { rerank: modelConfig };
@@ -926,7 +980,15 @@ export const ModelAddDialog = ({
-
+
+
+
diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx
index ad3cf0391..96eebc2e4 100644
--- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx
@@ -101,6 +101,8 @@ export const ModelDeleteDialog = ({
border: "border-purple-100",
};
case MODEL_TYPES.VLM:
+ case MODEL_TYPES.VLM2:
+ case MODEL_TYPES.VLM3:
return {
bg: "bg-yellow-50",
text: "text-yellow-600",
@@ -143,6 +145,8 @@ export const ModelDeleteDialog = ({
case MODEL_TYPES.TTS:
return "🔊";
case MODEL_TYPES.VLM:
+ case MODEL_TYPES.VLM2:
+ case MODEL_TYPES.VLM3:
return "👁️";
default:
return "⚙️";
@@ -167,6 +171,10 @@ export const ModelDeleteDialog = ({
return t("model.type.tts");
case MODEL_TYPES.VLM:
return t("model.type.vlm");
+ case MODEL_TYPES.VLM2:
+ return `${t("model.type.vlm")}2`;
+ case MODEL_TYPES.VLM3:
+ return `${t("model.type.vlm")}3`;
default:
return t("model.type.unknown");
}
@@ -346,7 +354,10 @@ export const ModelDeleteDialog = ({
if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl;
}
if (type === MODEL_TYPES.VLM) {
- const cfgUrl = modelConfig?.vlm?.apiConfig?.modelUrl;
+ const cfgUrl =
+ modelConfig?.vlm?.apiConfig?.modelUrl ||
+ modelConfig?.vlm2?.apiConfig?.modelUrl ||
+ modelConfig?.vlm3?.apiConfig?.modelUrl;
if (cfgUrl && cfgUrl.trim() !== "") return cfgUrl;
}
if (type === MODEL_TYPES.LLM) {
@@ -503,6 +514,22 @@ export const ModelDeleteDialog = ({
};
}
+ if (modelConfig.vlm2?.displayName === displayName) {
+ configUpdates.vlm2 = {
+ modelName: "",
+ displayName: "",
+ apiConfig: { apiKey: "", modelUrl: "" },
+ };
+ }
+
+ if (modelConfig.vlm3?.displayName === displayName) {
+ configUpdates.vlm3 = {
+ modelName: "",
+ displayName: "",
+ apiConfig: { apiKey: "", modelUrl: "" },
+ };
+ }
+
if (modelConfig.stt.displayName === displayName) {
configUpdates.stt = { modelName: "", displayName: "" };
}
@@ -1028,6 +1055,8 @@ export const ModelDeleteDialog = ({
MODEL_TYPES.MULTI_EMBEDDING,
MODEL_TYPES.RERANK,
MODEL_TYPES.VLM,
+ MODEL_TYPES.VLM2,
+ MODEL_TYPES.VLM3,
MODEL_TYPES.STT,
MODEL_TYPES.TTS,
] as ModelType[]
diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
index 3114c5535..cdac265f5 100644
--- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
@@ -87,6 +87,10 @@ export const ModelEditDialog = ({
form.type === MODEL_TYPES.EMBEDDING ||
form.type === MODEL_TYPES.MULTI_EMBEDDING;
const isRerankModel = form.type === MODEL_TYPES.RERANK;
+ const connectivityModelType =
+ form.type === MODEL_TYPES.VLM2 || form.type === MODEL_TYPES.VLM3
+ ? (MODEL_TYPES.VLM as ModelType)
+ : form.type;
const isFormValid = () => {
return form.name.trim() !== "" && form.url.trim() !== "";
@@ -106,11 +110,9 @@ export const ModelEditDialog = ({
});
try {
- const modelType = form.type as ModelType;
-
const config = {
modelName: form.name,
- modelType: modelType,
+ modelType: connectivityModelType,
baseUrl: form.url,
apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey,
maxTokens:
@@ -205,6 +207,8 @@ export const ModelEditDialog = ({
embedding: MODEL_TYPES.EMBEDDING,
multi_embedding: MODEL_TYPES.MULTI_EMBEDDING,
vlm: MODEL_TYPES.VLM,
+ vlm2: MODEL_TYPES.VLM2,
+ vlm3: MODEL_TYPES.VLM3,
rerank: MODEL_TYPES.RERANK,
tts: MODEL_TYPES.TTS,
stt: MODEL_TYPES.STT,
@@ -481,4 +485,4 @@ export const ProviderConfigEditDialog = ({
)
-}
\ No newline at end of file
+}
diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx
index 07eee5c06..36fcdbb31 100644
--- a/frontend/app/[locale]/models/components/modelConfig.tsx
+++ b/frontend/app/[locale]/models/components/modelConfig.tsx
@@ -56,7 +56,11 @@ const getModelData = (t: any) => ({
},
multimodal: {
title: t("modelConfig.category.multimodal"),
- options: [{ id: MODEL_TYPES.VLM, name: t("modelConfig.option.vlmModel") }],
+ options: [
+ { id: MODEL_TYPES.VLM, name: t("modelConfig.option.imageUnderstandingModel") },
+ { id: MODEL_TYPES.VLM2, name: t("modelConfig.option.imageGenerationModel") },
+ { id: MODEL_TYPES.VLM3, name: t("modelConfig.option.videoUnderstandingModel") },
+ ],
},
voice: {
title: t("modelConfig.category.voice"),
@@ -142,7 +146,7 @@ export const ModelConfigSection = forwardRef<
llm: { main: "" },
embedding: { embedding: "", multi_embedding: "" },
reranker: { reranker: "" },
- multimodal: { vlm: "" },
+ multimodal: { vlm: "", vlm2: "", vlm3: "" },
voice: { tts: "", stt: "" },
});
@@ -284,11 +288,23 @@ export const ModelConfigSection = forwardRef<
: true;
const vlm = modelConfig.vlm.displayName;
+ const vlm2 = modelConfig.vlm2?.displayName || "";
+ const vlm3 = modelConfig.vlm3?.displayName || "";
const vlmExists = vlm
? allModels.some(
(m) => m.displayName === vlm && m.type === MODEL_TYPES.VLM
)
: true;
+ const vlm2Exists = vlm2
+ ? allModels.some(
+ (m) => m.displayName === vlm2 && m.type === MODEL_TYPES.VLM2
+ )
+ : true;
+ const vlm3Exists = vlm3
+ ? allModels.some(
+ (m) => m.displayName === vlm3 && m.type === MODEL_TYPES.VLM3
+ )
+ : true;
const stt = modelConfig.stt.displayName;
const sttExists = stt
@@ -318,6 +334,8 @@ export const ModelConfigSection = forwardRef<
},
multimodal: {
vlm: vlmExists ? vlm : "",
+ vlm2: vlm2Exists ? vlm2 : "",
+ vlm3: vlm3Exists ? vlm3 : "",
},
voice: {
tts: ttsExists ? tts : "",
@@ -363,6 +381,14 @@ export const ModelConfigSection = forwardRef<
configUpdates.vlm = { modelName: "", displayName: "" };
}
+ if (!vlm2Exists && vlm2) {
+ configUpdates.vlm2 = { modelName: "", displayName: "" };
+ }
+
+ if (!vlm3Exists && vlm3) {
+ configUpdates.vlm3 = { modelName: "", displayName: "" };
+ }
+
if (!sttExists && stt) {
configUpdates.stt = { modelName: "", displayName: "" };
}
@@ -385,6 +411,8 @@ export const ModelConfigSection = forwardRef<
!!modelConfig.multiEmbedding.modelName ||
!!modelConfig.rerank.modelName ||
!!modelConfig.vlm.modelName ||
+ !!modelConfig.vlm2?.modelName ||
+ !!modelConfig.vlm3?.modelName ||
!!modelConfig.tts.modelName ||
!!modelConfig.stt.modelName;
@@ -441,11 +469,13 @@ export const ModelConfigSection = forwardRef<
const hasEmbedding = !!modelConfig.embedding.modelName;
const hasReranker = !!modelConfig.rerank.modelName;
const hasVlm = !!modelConfig.vlm.modelName;
+ const hasVlm2 = !!modelConfig.vlm2?.modelName;
+ const hasVlm3 = !!modelConfig.vlm3?.modelName;
const hasTts = !!modelConfig.tts.modelName;
const hasStt = !!modelConfig.stt.modelName;
hasSelectedModels =
- hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasTts || hasStt;
+ hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasVlm2 || hasVlm3 || hasTts || hasStt;
if (hasSelectedModels) {
currentSelectedModels.llm.main = modelConfig.llm.modelName;
@@ -455,6 +485,8 @@ export const ModelConfigSection = forwardRef<
modelConfig.multiEmbedding.modelName || "";
currentSelectedModels.reranker.reranker = modelConfig.rerank.modelName;
currentSelectedModels.multimodal.vlm = modelConfig.vlm.modelName;
+ currentSelectedModels.multimodal.vlm2 = modelConfig.vlm2?.modelName || "";
+ currentSelectedModels.multimodal.vlm3 = modelConfig.vlm3?.modelName || "";
currentSelectedModels.voice.tts = modelConfig.tts.modelName;
currentSelectedModels.voice.stt = modelConfig.stt.modelName;
} else {
@@ -492,7 +524,7 @@ export const ModelConfigSection = forwardRef<
} else if (category === "reranker") {
modelType = MODEL_TYPES.RERANK;
} else if (category === "multimodal") {
- modelType = MODEL_TYPES.VLM;
+ modelType = optionId as ModelType;
} else if (category === MODEL_TYPES.EMBEDDING) {
modelType =
optionId === MODEL_TYPES.MULTI_EMBEDDING
@@ -654,7 +686,7 @@ export const ModelConfigSection = forwardRef<
} else if (category === "reranker") {
modelType = MODEL_TYPES.RERANK;
} else if (category === "multimodal") {
- modelType = MODEL_TYPES.VLM;
+ modelType = option as ModelType;
} else if (category === MODEL_TYPES.EMBEDDING) {
modelType =
option === MODEL_TYPES.MULTI_EMBEDDING
@@ -679,7 +711,7 @@ export const ModelConfigSection = forwardRef<
) {
configKey = "multiEmbedding";
} else if (category === "multimodal") {
- configKey = MODEL_TYPES.VLM;
+ configKey = option;
} else if (category === "reranker") {
configKey = MODEL_TYPES.RERANK;
} else if (category === "voice" && option === "tts") {
@@ -1005,7 +1037,7 @@ export const ModelConfigSection = forwardRef<
? MODEL_TYPES.TTS
: MODEL_TYPES.STT
: key === "multimodal"
- ? MODEL_TYPES.VLM
+ ? (option.id as ModelType)
: key === MODEL_TYPES.EMBEDDING &&
option.id === MODEL_TYPES.MULTI_EMBEDDING
? MODEL_TYPES.MULTI_EMBEDDING
diff --git a/frontend/const/chatConfig.ts b/frontend/const/chatConfig.ts
index fc0dbe6d5..27b3b887d 100644
--- a/frontend/const/chatConfig.ts
+++ b/frontend/const/chatConfig.ts
@@ -38,6 +38,12 @@ export const chatConfig = {
// Supported document file extensions
documentExtensions: ["pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "epub", "html", "xml"],
+
+ // Supported audio file extensions
+ audioExtensions: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"],
+
+ // Supported video file extensions
+ videoExtensions: ["mp4", "mov", "m4v", "avi", "mkv", "webm", "wmv", "flv"],
// Supported text document extensions
supportedTextExtensions: ["md", "markdown", "txt", "csv", "json"],
@@ -73,6 +79,12 @@ export const chatConfig = {
// Compressed file
compressed: ["zip", "rar", "7z", "tar", "gz"],
+
+ // Audio files
+ audio: ["mp3", "wav", "m4a", "aac", "ogg", "oga", "flac", "webm"],
+
+ // Video files
+ video: ["mp4", "mov", "m4v", "avi", "mkv", "wmv", "flv"],
},
// File preview type constants
@@ -148,4 +160,4 @@ export const MESSAGE_ROLES = {
USER: "user" as const,
ASSISTANT: "assistant" as const,
SYSTEM: "system" as const,
-} as const;
\ No newline at end of file
+} as const;
diff --git a/frontend/const/modelConfig.ts b/frontend/const/modelConfig.ts
index 9bdc5a4a8..b7762ace0 100644
--- a/frontend/const/modelConfig.ts
+++ b/frontend/const/modelConfig.ts
@@ -7,6 +7,8 @@ export const MODEL_TYPES = {
STT: "stt",
TTS: "tts",
VLM: "vlm",
+ VLM2: "vlm2",
+ VLM3: "vlm3",
} as const;
// Model source constants
diff --git a/frontend/hooks/model/useDashscopeModelList.ts b/frontend/hooks/model/useDashscopeModelList.ts
index b44348fe5..5d1035e8a 100644
--- a/frontend/hooks/model/useDashscopeModelList.ts
+++ b/frontend/hooks/model/useDashscopeModelList.ts
@@ -39,7 +39,9 @@ export const useDashscopeModelList = ({
const modelType =
form.type === "embedding" && form.isMultimodal
? ("multi_embedding" as ModelType)
- : form.type;
+ : form.type === "vlm2" || form.type === "vlm3"
+ ? ("vlm" as ModelType)
+ : form.type;
try {
// Use manage interface if tenantId is provided (for super admin)
diff --git a/frontend/hooks/model/useTokenponyModelList.ts b/frontend/hooks/model/useTokenponyModelList.ts
index 0a7e23581..0c502a404 100644
--- a/frontend/hooks/model/useTokenponyModelList.ts
+++ b/frontend/hooks/model/useTokenponyModelList.ts
@@ -39,7 +39,9 @@ export const useTokenPonyModelList = ({
const modelType =
form.type === "embedding" && form.isMultimodal
? ("multi_embedding" as ModelType)
- : form.type;
+ : form.type === "vlm2" || form.type === "vlm3"
+ ? ("vlm" as ModelType)
+ : form.type;
try {
// Use manage interface if tenantId is provided (for super admin)
diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts
index 8d4c4ccea..4d095e681 100644
--- a/frontend/hooks/useConfig.ts
+++ b/frontend/hooks/useConfig.ts
@@ -76,6 +76,22 @@ const defaultConfig: GlobalConfig = {
modelUrl: "",
},
},
+ vlm2: {
+ modelName: "",
+ displayName: "",
+ apiConfig: {
+ apiKey: "",
+ modelUrl: "",
+ },
+ },
+ vlm3: {
+ modelName: "",
+ displayName: "",
+ apiConfig: {
+ apiKey: "",
+ modelUrl: "",
+ },
+ },
stt: {
modelName: "",
displayName: "",
@@ -161,6 +177,8 @@ function transformBackendToFrontend(backendConfig: any): GlobalConfig {
),
rerank: transformModelEntry(backendConfig.models.rerank),
vlm: transformModelEntry(backendConfig.models.vlm),
+ vlm2: transformModelEntry(backendConfig.models.vlm2),
+ vlm3: transformModelEntry(backendConfig.models.vlm3),
stt: transformVoiceModelEntry(backendConfig.models.stt),
tts: transformVoiceModelEntry(backendConfig.models.tts),
}
@@ -195,7 +213,10 @@ function loadConfigFromStorage(): GlobalConfig | null {
if (storedModelConfig) {
try {
- mergedConfig.models = JSON.parse(storedModelConfig);
+ mergedConfig.models = deepMerge(
+ mergedConfig.models,
+ JSON.parse(storedModelConfig)
+ );
} catch (error) {
log.error("Failed to parse model config:", error);
}
@@ -285,7 +306,24 @@ export function useConfig() {
const config: GlobalConfig = (query.data as GlobalConfig | undefined) ?? defaultConfig;
// Whether config has selected a VLM model
- const isVlmAvailable = !!(config?.models?.vlm?.modelName || config?.models?.vlm?.displayName);
+ const isVlmAvailable = !!(
+ config?.models?.vlm?.modelName ||
+ config?.models?.vlm?.displayName ||
+ config?.models?.vlm2?.modelName ||
+ config?.models?.vlm2?.displayName ||
+ config?.models?.vlm3?.modelName ||
+ config?.models?.vlm3?.displayName
+ );
+
+ const isImageUnderstandingAvailable = !!(
+ config?.models?.vlm?.modelName ||
+ config?.models?.vlm?.displayName
+ );
+
+ const isVideoUnderstandingAvailable = !!(
+ config?.models?.vlm3?.modelName ||
+ config?.models?.vlm3?.displayName
+ );
// Whether config has selected an Embedding model
const isEmbeddingAvailable = !!(config?.models?.embedding?.modelName || config?.models?.embedding?.displayName);
@@ -368,6 +406,8 @@ export function useConfig() {
appConfig: config?.app,
modelConfig: config?.models,
isVlmAvailable,
+ isImageUnderstandingAvailable,
+ isVideoUnderstandingAvailable,
isEmbeddingAvailable,
defaultLlmModelName,
updateAppConfig,
diff --git a/frontend/lib/chat/chatAttachmentUtils.ts b/frontend/lib/chat/chatAttachmentUtils.ts
index c85615b4e..bff686ca1 100644
--- a/frontend/lib/chat/chatAttachmentUtils.ts
+++ b/frontend/lib/chat/chatAttachmentUtils.ts
@@ -69,6 +69,19 @@ export const uploadAttachments = async (
});
}
+ const failedResults = uploadResult.results.filter((result) => !result.success);
+ if (failedResults.length > 0 || uploadResult.success_count < attachments.length) {
+ const failedMessage = failedResults
+ .map((result) => `${result.file_name || "file"}: ${result.error || "Upload failed"}`)
+ .join("; ");
+ return {
+ uploadedFileUrls,
+ objectNames,
+ presignedUrls,
+ error: failedMessage || "Upload failed",
+ };
+ }
+
return { uploadedFileUrls, objectNames, presignedUrls };
} catch (error) {
log.error(t("chatPreprocess.fileUploadFailed"), error);
diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json
index 22c17c2ca..2b22c3156 100644
--- a/frontend/public/locales/en/common.json
+++ b/frontend/public/locales/en/common.json
@@ -1,4 +1,4 @@
-{
+{
"assistant.name": "Nexent",
"mainPage.layout.title": "Nexent | AI Agents",
@@ -811,6 +811,9 @@
"model.type.llm": "Large Language Model",
"model.type.embedding": "Embedding Model",
"model.type.vlm": "Vision Language Model",
+ "model.type.imageUnderstanding": "Image Understanding Model",
+ "model.type.imageGeneration": "Image Generation Model",
+ "model.type.videoUnderstanding": "Video Understanding Model",
"model.type.rerank": "Rerank Model",
"model.type.stt": "Speech-to-Text Model",
"model.type.tts": "Text-to-Speech Model",
@@ -879,6 +882,9 @@
"modelConfig.option.multiEmbeddingModel": "Multimodal Embedding Model",
"modelConfig.option.rerankerModel": "Reranker Model",
"modelConfig.option.vlmModel": "Vision Language Model",
+ "modelConfig.option.imageUnderstandingModel": "Image Understanding Model",
+ "modelConfig.option.imageGenerationModel": "Image Generation Model",
+ "modelConfig.option.videoUnderstandingModel": "Video Understanding Model",
"modelConfig.option.ttsModel": "Text-to-Speech Model",
"modelConfig.option.sttModel": "Speech-to-Text Model",
"modelConfig.error.loadList": "Failed to load model list:",
diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json
index 1cc83a802..b0e9d7ef1 100644
--- a/frontend/public/locales/zh/common.json
+++ b/frontend/public/locales/zh/common.json
@@ -1,4 +1,4 @@
-{
+{
"assistant.name": "Nexent",
"mainPage.layout.title": "Nexent | 智能问答",
@@ -811,6 +811,9 @@
"model.type.llm": "大语言模型",
"model.type.embedding": "向量模型",
"model.type.vlm": "视觉语言模型",
+ "model.type.imageUnderstanding": "图片理解模型",
+ "model.type.imageGeneration": "图片生成模型",
+ "model.type.videoUnderstanding": "视频理解模型",
"model.type.rerank": "重排模型",
"model.type.stt": "语音识别模型",
"model.type.tts": "语音合成模型",
@@ -880,6 +883,9 @@
"modelConfig.option.multiEmbeddingModel": "多模态向量模型",
"modelConfig.option.rerankerModel": "重排模型",
"modelConfig.option.vlmModel": "视觉语言模型",
+ "modelConfig.option.imageUnderstandingModel": "图片理解模型",
+ "modelConfig.option.imageGenerationModel": "图片生成模型",
+ "modelConfig.option.videoUnderstandingModel": "视频理解模型",
"modelConfig.option.ttsModel": "语音合成模型",
"modelConfig.option.sttModel": "语音识别模型",
"modelConfig.error.loadList": "加载模型列表失败:",
diff --git a/frontend/tailwind.config.ts b/frontend/tailwind.config.js
similarity index 100%
rename from frontend/tailwind.config.ts
rename to frontend/tailwind.config.js
diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts
index a9f918d71..1a1df2b98 100644
--- a/frontend/types/modelConfig.ts
+++ b/frontend/types/modelConfig.ts
@@ -31,6 +31,8 @@ export type ModelType =
| "stt"
| "tts"
| "vlm"
+ | "vlm2"
+ | "vlm3"
| "multi_embedding";
// Model option interface
@@ -99,6 +101,8 @@ export interface ModelConfig {
multiEmbedding: SingleModelConfig;
rerank: SingleModelConfig;
vlm: SingleModelConfig;
+ vlm2: SingleModelConfig;
+ vlm3: SingleModelConfig;
stt: STTModelConfig;
tts: TTSModelConfig;
}
diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py
index 023c8348e..c0dc83efc 100644
--- a/sdk/nexent/core/agents/nexent_agent.py
+++ b/sdk/nexent/core/agents/nexent_agent.py
@@ -115,7 +115,7 @@ def create_local_tool(self, tool_config: ToolConfig):
data_process_service_url=tool_config.metadata.get("data_process_service_url", []),
validate_url_access=validate_url_access,
**params)
- elif class_name == "AnalyzeImageTool":
+ elif class_name in ["AnalyzeImageTool", "AnalyzeAudioTool", "AnalyzeVideoTool"]:
# Extract validate_url_access from metadata if it's callable
validate_url_access = tool_config.metadata.get("validate_url_access") if tool_config.metadata else None
if validate_url_access is not None and not callable(validate_url_access):
@@ -493,4 +493,4 @@ def _val_width(vals, extra_val=None):
# Optional: write to local file
with open("nexent_context_metrics.log", "a", encoding="utf-8") as f:
- f.write("\n".join(lines) + "\n")
\ No newline at end of file
+ f.write("\n".join(lines) + "\n")
diff --git a/sdk/nexent/core/models/openai_vlm.py b/sdk/nexent/core/models/openai_vlm.py
index 1babb0057..cbc7388d6 100644
--- a/sdk/nexent/core/models/openai_vlm.py
+++ b/sdk/nexent/core/models/openai_vlm.py
@@ -126,6 +126,47 @@ def prepare_image_message(self, image_input: Union[str, BinaryIO], system_prompt
return messages
+ def prepare_media_message(
+ self,
+ media_input: Union[str, BinaryIO],
+ media_type: str,
+ content_type: str,
+ system_prompt: str) -> List[Dict[str, Any]]:
+ """
+ Prepare an OpenAI-compatible multimodal message for audio or video inputs.
+
+ Args:
+ media_input: Media file path or file stream object.
+ media_type: Either "audio" or "video".
+ content_type: MIME type for the data URL.
+ system_prompt: System prompt.
+
+ Returns:
+ List[Dict[str, Any]]: Prepared message list.
+ """
+ if media_type not in ("audio", "video"):
+ raise ValueError(f"Unsupported media type: {media_type}")
+
+ base64_media = self.encode_image(media_input)
+ media_url_key = f"{media_type}_url"
+ media_config: Dict[str, Any] = {"url": f"data:{content_type};base64,{base64_media}"}
+ if media_type == "video":
+ media_config.update({"detail": "high", "max_frames": 16, "fps": 1})
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": media_url_key,
+ media_url_key: media_config
+ },
+ {"type": "text", "text": system_prompt}
+ ]
+ }
+ ]
+ return messages
+
def analyze_image(self, image_input: Union[str, BinaryIO],
system_prompt: str = "Please describe this picture concisely and carefully, within 200 words.", stream: bool = True,
**kwargs) -> ChatMessage:
@@ -144,3 +185,23 @@ def analyze_image(self, image_input: Union[str, BinaryIO],
messages = self.prepare_image_message(image_input, system_prompt)
# Call __call__ explicitly so instance-level mocks work in tests.
return self.__call__(messages=messages, **kwargs)
+
+ def analyze_audio(
+ self,
+ audio_input: Union[str, BinaryIO],
+ system_prompt: str = "Please analyze this audio carefully.",
+ content_type: str = "audio/mpeg",
+ **kwargs) -> ChatMessage:
+ """Analyze audio content using the configured multimodal model."""
+ messages = self.prepare_media_message(audio_input, "audio", content_type, system_prompt)
+ return self.__call__(messages=messages, **kwargs)
+
+ def analyze_video(
+ self,
+ video_input: Union[str, BinaryIO],
+ system_prompt: str = "Please analyze this video carefully.",
+ content_type: str = "video/mp4",
+ **kwargs) -> ChatMessage:
+ """Analyze video content using the configured multimodal model."""
+ messages = self.prepare_media_message(video_input, "video", content_type, system_prompt)
+ return self.__call__(messages=messages, **kwargs)
diff --git a/sdk/nexent/core/prompts/analyze_audio_en.yaml b/sdk/nexent/core/prompts/analyze_audio_en.yaml
new file mode 100644
index 000000000..eee0bb060
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_audio_en.yaml
@@ -0,0 +1,13 @@
+# Audio Understanding Prompt Templates
+
+system_prompt: |-
+ The user has asked a question: {{ query }}. Please analyze this audio from the perspective of answering this question, within 300 words.
+
+ **Audio Analysis Requirements:**
+ 1. Focus on speech, sound events, tone, timing, and other audio content relevant to the user's question
+ 2. If speech is present, summarize or transcribe the key spoken content when possible
+ 3. Keep the answer concise and grounded in observable audio evidence
+ 4. Avoid guessing identities or facts that cannot be inferred from the audio
+
+user_prompt: |
+ Please listen to this audio and describe it from the perspective of answering the user's question.
diff --git a/sdk/nexent/core/prompts/analyze_audio_zh.yaml b/sdk/nexent/core/prompts/analyze_audio_zh.yaml
new file mode 100644
index 000000000..ae6f1fa0d
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_audio_zh.yaml
@@ -0,0 +1,13 @@
+# 音频理解 Prompt 模板
+
+system_prompt: |-
+ 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段音频,控制在 300 字以内。
+
+ **音频分析要求:**
+ 1. 关注与用户问题相关的语音、声音事件、语气、节奏和其他音频内容
+ 2. 如果包含人声,请尽可能总结或转写关键口语内容
+ 3. 回答要简洁,并基于音频中可观察到的信息
+ 4. 不要猜测无法从音频中判断的身份或事实
+
+user_prompt: |
+ 请仔细聆听这段音频,并从回答用户问题的角度进行描述。
diff --git a/sdk/nexent/core/prompts/analyze_video_en.yaml b/sdk/nexent/core/prompts/analyze_video_en.yaml
new file mode 100644
index 000000000..7834ca7f3
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_video_en.yaml
@@ -0,0 +1,13 @@
+# Video Understanding Prompt Templates
+
+system_prompt: |-
+ The user has asked a question: {{ query }}. Please analyze this video from the perspective of answering this question, within 300 words.
+
+ **Video Analysis Requirements:**
+ 1. Focus on scenes, actions, objects, people, visible text, and temporal changes relevant to the user's question
+ 2. Mention important audio cues only when they help answer the question
+ 3. Keep the answer concise, structured, and grounded in visible or audible evidence
+ 4. Avoid over-interpreting intent or facts that cannot be inferred from the video
+
+user_prompt: |
+ Please watch this video and describe it from the perspective of answering the user's question.
diff --git a/sdk/nexent/core/prompts/analyze_video_zh.yaml b/sdk/nexent/core/prompts/analyze_video_zh.yaml
new file mode 100644
index 000000000..e83a1676d
--- /dev/null
+++ b/sdk/nexent/core/prompts/analyze_video_zh.yaml
@@ -0,0 +1,13 @@
+# 视频理解 Prompt 模板
+
+system_prompt: |-
+ 用户提出的问题是:{{ query }}。请从回答该问题的角度分析这段视频,控制在 300 字以内。
+
+ **视频分析要求:**
+ 1. 关注与用户问题相关的场景、动作、物体、人物、可见文字和时间变化
+ 2. 只有在有助于回答问题时,才补充重要的音频线索
+ 3. 回答要简洁、有条理,并基于视频中可见或可听的信息
+ 4. 不要过度推断无法从视频中判断的意图或事实
+
+user_prompt: |
+ 请仔细观看这段视频,并从回答用户问题的角度进行描述。
diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py
index 851690f16..a640cb5ff 100644
--- a/sdk/nexent/core/tools/__init__.py
+++ b/sdk/nexent/core/tools/__init__.py
@@ -19,6 +19,8 @@
from .terminal_tool import TerminalTool
from .analyze_text_file_tool import AnalyzeTextFileTool
from .analyze_image_tool import AnalyzeImageTool
+from .analyze_audio_tool import AnalyzeAudioTool
+from .analyze_video_tool import AnalyzeVideoTool
from .run_skill_script_tool import run_skill_script
from .read_skill_md_tool import read_skill_md
from .read_skill_config_tool import read_skill_config
@@ -47,6 +49,8 @@
"TerminalTool",
"AnalyzeTextFileTool",
"AnalyzeImageTool",
+ "AnalyzeAudioTool",
+ "AnalyzeVideoTool",
"run_skill_script",
"read_skill_md",
"read_skill_config"
diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py
new file mode 100644
index 000000000..c7509a6c2
--- /dev/null
+++ b/sdk/nexent/core/tools/analyze_audio_tool.py
@@ -0,0 +1,169 @@
+"""
+Analyze Audio Tool
+
+Analyze audio using the configured video understanding model.
+Supports audio from S3, HTTP, and HTTPS URLs.
+"""
+
+import logging
+from io import BytesIO
+from typing import List
+
+from jinja2 import StrictUndefined, Template
+from pydantic import Field
+from smolagents.tools import Tool
+
+from ...core.models import OpenAIVLModel
+from ...core.utils.observer import MessageObserver, ProcessType
+from ...core.utils.prompt_template_utils import get_prompt_template
+from ...core.utils.tools_common_message import ToolCategory, ToolSign
+from ...multi_modal.load_save_object import LoadSaveObjectManager
+from ...multi_modal.utils import detect_content_type_from_bytes
+from ...storage import MinIOStorageClient
+
+logger = logging.getLogger("analyze_audio_tool")
+
+
+class AnalyzeAudioTool(Tool):
+ """Tool for understanding and analyzing audio using the video understanding model."""
+
+ name = "analyze_audio"
+ description = (
+ "This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n"
+ "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
+ "HTTP, and HTTPS URLs.\n"
+ "Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query."
+ )
+ description_zh = (
+ "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。"
+ "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ )
+
+ inputs = {
+ "audio_urls_list": {
+ "type": "array",
+ "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
+ "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。"
+ },
+ "query": {
+ "type": "string",
+ "description": "User's question to guide the audio analysis",
+ "description_zh": "用户用于指导音频分析的问题"
+ }
+ }
+
+ init_param_descriptions = {
+ "observer": {"description": "Message observer"},
+ "vlm_model": {"description": "The video understanding model to use"},
+ "storage_client": {"description": "Storage client for downloading files"},
+ "validate_url_access": {
+ "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)"
+ }
+ }
+ output_type = "array"
+ category = ToolCategory.MULTIMODAL.value
+ tool_sign = ToolSign.MULTIMODAL_OPERATION.value
+
+ def __init__(
+ self,
+ observer: MessageObserver = Field(
+ description="Message observer",
+ default=None,
+ exclude=True),
+ vlm_model: OpenAIVLModel = Field(
+ description="The video understanding model to use",
+ default=None,
+ exclude=True),
+ storage_client: MinIOStorageClient = Field(
+ description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.",
+ default=None,
+ exclude=True),
+ validate_url_access: callable = Field(
+ description="Callback function to validate URL access permissions",
+ default=None,
+ exclude=True)
+ ):
+ super().__init__()
+ self.observer = observer
+ self.vlm_model = vlm_model
+ self.storage_client = storage_client
+ self._is_chinese = bool(observer and observer.lang == "zh")
+
+ validate_callback = None
+ if validate_url_access is not None and callable(validate_url_access):
+ validate_callback = validate_url_access
+ self.mm = LoadSaveObjectManager(
+ storage_client=self.storage_client,
+ validate_url_access=validate_callback
+ )
+ self.forward = self.mm.load_object(
+ input_names=["audio_urls_list"])(self._forward_impl)
+
+ self.running_prompt_zh = "正在分析音频..."
+ self.running_prompt_en = "Analyzing audio..."
+
+ def _validate_audio_capable_model(self) -> None:
+ """Fail early for SiliconFlow models that are known not to accept audio input."""
+ client_kwargs = getattr(self.vlm_model, "client_kwargs", {}) or {}
+ base_url = client_kwargs.get("base_url", "") if isinstance(client_kwargs, dict) else ""
+ model_id = str(getattr(self.vlm_model, "model_id", "") or "")
+
+ if "siliconflow" in str(base_url).lower() and model_id and "omni" not in model_id.lower():
+ raise ValueError(
+ "The selected video understanding model does not support audio input on SiliconFlow. "
+ "Please choose a Qwen3-Omni model for analyze_audio."
+ )
+
+ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
+ """Analyze audio files and return one result per audio input."""
+ if self.vlm_model is None:
+ error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试"
+ error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ logger.error(error_msg)
+ raise Exception(error_msg)
+ self._validate_audio_capable_model()
+
+ if self.observer:
+ running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en
+ self.observer.add_message("", ProcessType.TOOL, running_prompt)
+
+ if audio_urls_list is None:
+ raise ValueError("audio_urls cannot be None")
+ if not isinstance(audio_urls_list, list):
+ raise ValueError("audio_urls must be a list of bytes")
+ if not audio_urls_list:
+ raise ValueError("audio_urls must contain at least one audio file")
+
+ language = self.observer.lang if self.observer else "en"
+ prompts = get_prompt_template(
+ template_type='analyze_audio', language=language)
+ system_prompt = Template(
+ prompts['system_prompt'], undefined=StrictUndefined).render({'query': query})
+
+ try:
+ analysis_results: List[str] = []
+ for index, audio_bytes in enumerate(audio_urls_list, start=1):
+ logger.info(f"Analyzing audio #{index}, query: {query}")
+ content_type = detect_content_type_from_bytes(audio_bytes)
+ if not content_type.startswith("audio/"):
+ content_type = "audio/mpeg"
+ audio_stream = BytesIO(audio_bytes)
+ try:
+ response = self.vlm_model.analyze_audio(
+ audio_input=audio_stream,
+ system_prompt=system_prompt,
+ content_type=content_type
+ )
+ except Exception as e:
+ error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。"
+ error_msg_en = f"Failed to analyze audio {index}: {str(e)}. Please check if the video understanding model is configured correctly."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ raise Exception(error_msg)
+
+ analysis_results.append(response.content)
+
+ return analysis_results
+ except Exception as e:
+ logger.error(f"Error analyzing audio: {str(e)}", exc_info=True)
+ raise Exception(f"Error analyzing audio: {str(e)}")
diff --git a/sdk/nexent/core/tools/analyze_image_tool.py b/sdk/nexent/core/tools/analyze_image_tool.py
index 3851a896b..f7640a9dc 100644
--- a/sdk/nexent/core/tools/analyze_image_tool.py
+++ b/sdk/nexent/core/tools/analyze_image_tool.py
@@ -24,17 +24,17 @@
class AnalyzeImageTool(Tool):
- """Tool for understanding and analyzing image using a visual language model"""
+ """Tool for understanding and analyzing images using the image understanding model."""
name = "analyze_image"
description = (
- "This tool uses a visual language model to understand images based on your query and then returns a description of the image.\n"
+ "This tool uses the configured image understanding model to understand images based on your query and then returns a description of the image.\n"
"It is used to understand and analyze multiple images, with image sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
"HTTP, and HTTPS URLs.\n"
"Use this tool when you want to retrieve information contained in an image and provide the image's URL and your query."
)
- description_zh = "使用视觉语言模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ description_zh = "使用图片理解模型,根据你的提示词来理解图像,并返回图像的描述。可用于理解和分析多张图片,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
inputs = {
"image_urls_list": {
@@ -54,7 +54,7 @@ class AnalyzeImageTool(Tool):
"description": "Message observer"
},
"vlm_model": {
- "description": "The VLM model to use"
+ "description": "The image understanding model to use"
},
"storage_client": {
"description": "Storage client for downloading files"
@@ -74,7 +74,7 @@ def __init__(
default=None,
exclude=True),
vlm_model: OpenAIVLModel = Field(
- description="The VLM model to use",
+ description="The image understanding model to use",
default=None,
exclude=True),
storage_client: MinIOStorageClient = Field(
@@ -130,10 +130,10 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]:
Raises:
Exception: If the image cannot be downloaded or analyzed.
"""
- # Check if VLM model is available
+ # Check if the image understanding model is available.
if self.vlm_model is None:
- error_msg_zh = "视觉语言模型(VLM)未配置,请联系管理员配置VLM模型后重试"
- error_msg_en = "Vision Language Model (VLM) is not configured. Please contact your administrator to configure the VLM model and try again."
+ error_msg_zh = "图片理解模型未配置,请联系管理员配置图片理解模型后重试"
+ error_msg_en = "Image understanding model is not configured. Please contact your administrator to configure the image understanding model and try again."
error_msg = error_msg_zh if self._is_chinese else error_msg_en
logger.error(error_msg)
raise Exception(error_msg)
@@ -170,8 +170,8 @@ def _forward_impl(self, image_urls_list: List[bytes], query: str) -> List[str]:
system_prompt=system_prompt
)
except Exception as e:
- error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查VLM模型配置是否正确。"
- error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the VLM model is configured correctly."
+ error_msg_zh = f"图片{index}分析失败: {str(e)}。请检查图片理解模型配置是否正确。"
+ error_msg_en = f"Failed to analyze image {index}: {str(e)}. Please check if the image understanding model is configured correctly."
error_msg = error_msg_zh if self._is_chinese else error_msg_en
raise Exception(error_msg)
diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py
new file mode 100644
index 000000000..3dc033551
--- /dev/null
+++ b/sdk/nexent/core/tools/analyze_video_tool.py
@@ -0,0 +1,156 @@
+"""
+Analyze Video Tool
+
+Analyze videos using the configured video understanding model.
+Supports videos from S3, HTTP, and HTTPS URLs.
+"""
+
+import logging
+from io import BytesIO
+from typing import List
+
+from jinja2 import StrictUndefined, Template
+from pydantic import Field
+from smolagents.tools import Tool
+
+from ...core.models import OpenAIVLModel
+from ...core.utils.observer import MessageObserver, ProcessType
+from ...core.utils.prompt_template_utils import get_prompt_template
+from ...core.utils.tools_common_message import ToolCategory, ToolSign
+from ...multi_modal.load_save_object import LoadSaveObjectManager
+from ...multi_modal.utils import detect_content_type_from_bytes
+from ...storage import MinIOStorageClient
+
+logger = logging.getLogger("analyze_video_tool")
+
+
+class AnalyzeVideoTool(Tool):
+ """Tool for understanding and analyzing videos using the video understanding model."""
+
+ name = "analyze_video"
+ description = (
+ "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n"
+ "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
+ "HTTP, and HTTPS URLs.\n"
+ "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query."
+ )
+ description_zh = (
+ "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。"
+ "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ )
+
+ inputs = {
+ "video_urls_list": {
+ "type": "array",
+ "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
+ "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。"
+ },
+ "query": {
+ "type": "string",
+ "description": "User's question to guide the video analysis",
+ "description_zh": "用户用于指导视频分析的问题"
+ }
+ }
+
+ init_param_descriptions = {
+ "observer": {"description": "Message observer"},
+ "vlm_model": {"description": "The video understanding model to use"},
+ "storage_client": {"description": "Storage client for downloading files"},
+ "validate_url_access": {
+ "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)"
+ }
+ }
+ output_type = "array"
+ category = ToolCategory.MULTIMODAL.value
+ tool_sign = ToolSign.MULTIMODAL_OPERATION.value
+
+ def __init__(
+ self,
+ observer: MessageObserver = Field(
+ description="Message observer",
+ default=None,
+ exclude=True),
+ vlm_model: OpenAIVLModel = Field(
+ description="The video understanding model to use",
+ default=None,
+ exclude=True),
+ storage_client: MinIOStorageClient = Field(
+ description="Storage client for downloading files from S3 URLs, HTTP URLs, and HTTPS URLs.",
+ default=None,
+ exclude=True),
+ validate_url_access: callable = Field(
+ description="Callback function to validate URL access permissions",
+ default=None,
+ exclude=True)
+ ):
+ super().__init__()
+ self.observer = observer
+ self.vlm_model = vlm_model
+ self.storage_client = storage_client
+ self._is_chinese = bool(observer and observer.lang == "zh")
+
+ validate_callback = None
+ if validate_url_access is not None and callable(validate_url_access):
+ validate_callback = validate_url_access
+ self.mm = LoadSaveObjectManager(
+ storage_client=self.storage_client,
+ validate_url_access=validate_callback
+ )
+ self.forward = self.mm.load_object(
+ input_names=["video_urls_list"])(self._forward_impl)
+
+ self.running_prompt_zh = "正在分析视频..."
+ self.running_prompt_en = "Analyzing video..."
+
+ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
+ """Analyze videos and return one result per video input."""
+ if self.vlm_model is None:
+ error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试"
+ error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ logger.error(error_msg)
+ raise Exception(error_msg)
+
+ if self.observer:
+ running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en
+ self.observer.add_message("", ProcessType.TOOL, running_prompt)
+
+ if video_urls_list is None:
+ raise ValueError("video_urls cannot be None")
+ if not isinstance(video_urls_list, list):
+ raise ValueError("video_urls must be a list of bytes")
+ if not video_urls_list:
+ raise ValueError("video_urls must contain at least one video")
+
+ language = self.observer.lang if self.observer else "en"
+ prompts = get_prompt_template(
+ template_type='analyze_video', language=language)
+ system_prompt = Template(
+ prompts['system_prompt'], undefined=StrictUndefined).render({'query': query})
+
+ try:
+ analysis_results: List[str] = []
+ for index, video_bytes in enumerate(video_urls_list, start=1):
+ logger.info(f"Analyzing video #{index}, query: {query}")
+ content_type = detect_content_type_from_bytes(video_bytes)
+ if not content_type.startswith("video/"):
+ content_type = "video/mp4"
+ video_stream = BytesIO(video_bytes)
+ try:
+ response = self.vlm_model.analyze_video(
+ video_input=video_stream,
+ system_prompt=system_prompt,
+ content_type=content_type
+ )
+ except Exception as e:
+ error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。"
+ error_msg_en = f"Failed to analyze video {index}: {str(e)}. Please check if the video understanding model is configured correctly."
+ error_msg = error_msg_zh if self._is_chinese else error_msg_en
+ raise Exception(error_msg)
+
+ analysis_results.append(response.content)
+
+ return analysis_results
+ except Exception as e:
+ logger.error(f"Error analyzing video: {str(e)}", exc_info=True)
+ raise Exception(f"Error analyzing video: {str(e)}")
diff --git a/sdk/nexent/core/utils/prompt_template_utils.py b/sdk/nexent/core/utils/prompt_template_utils.py
index ad06e9119..24b273876 100644
--- a/sdk/nexent/core/utils/prompt_template_utils.py
+++ b/sdk/nexent/core/utils/prompt_template_utils.py
@@ -17,6 +17,14 @@
LANGUAGE["ZH"]: 'core/prompts/analyze_image_zh.yaml',
LANGUAGE["EN"]: 'core/prompts/analyze_image_en.yaml'
},
+ 'analyze_audio': {
+ LANGUAGE["ZH"]: 'core/prompts/analyze_audio_zh.yaml',
+ LANGUAGE["EN"]: 'core/prompts/analyze_audio_en.yaml'
+ },
+ 'analyze_video': {
+ LANGUAGE["ZH"]: 'core/prompts/analyze_video_zh.yaml',
+ LANGUAGE["EN"]: 'core/prompts/analyze_video_en.yaml'
+ },
'analyze_file': {
LANGUAGE["ZH"]: 'core/prompts/analyze_file_zh.yaml',
LANGUAGE["EN"]: 'core/prompts/analyze_file_en.yaml'
@@ -30,6 +38,8 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw
Args:
template_type: Template type, supports the following values:
- 'analyze_image': Analyze image template
+ - 'analyze_audio': Analyze audio template
+ - 'analyze_video': Analyze video template
- 'analyze_file': Analyze file template (for text files)
language: Language code ('zh' or 'en')
**kwargs: Additional parameters, for agent type need to pass is_manager parameter
@@ -52,4 +62,4 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw
# Read and return template content
with open(absolute_template_path, 'r', encoding='utf-8') as f:
- return yaml.safe_load(f)
\ No newline at end of file
+ return yaml.safe_load(f)
diff --git a/sdk/nexent/multi_modal/utils.py b/sdk/nexent/multi_modal/utils.py
index e118f6940..bcd6cdd35 100644
--- a/sdk/nexent/multi_modal/utils.py
+++ b/sdk/nexent/multi_modal/utils.py
@@ -34,10 +34,10 @@ def is_url(url: str) -> Optional[UrlType]:
if url.startswith("https://"):
return "https"
- if url.startswith("s3://"):
- bucket_path = url.replace("s3://", "", 1)
+ if url.startswith("s3://") or url.startswith("s3:/"):
+ bucket_path = url.replace("s3://", "", 1) if url.startswith("s3://") else url.replace("s3:/", "", 1).lstrip("/")
bucket_object = bucket_path.split("/", 1)
- if len(bucket_object) == 2 and all(bucket_object):
+ if len(bucket_object) == 2 and all(bucket_object) and ":" not in bucket_object[0]:
return "s3"
return None
@@ -321,6 +321,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
Supports formats:
- s3://bucket/key
+ - s3:/bucket/key
- /bucket/key (MinIO path format)
Args:
@@ -335,11 +336,16 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
if not s3_url:
raise ValueError("S3 URL cannot be empty")
- if s3_url.startswith('s3://'):
- parts = s3_url.replace('s3://', '').split('/', 1)
+ if s3_url.startswith('s3://') or s3_url.startswith('s3:/'):
+ normalized_url = (
+ s3_url.replace('s3://', '', 1)
+ if s3_url.startswith('s3://')
+ else s3_url.replace('s3:/', '', 1).lstrip('/')
+ )
+ parts = normalized_url.split('/', 1)
if len(parts) == 2:
bucket, object_name = parts
- if not bucket or not object_name:
+ if not bucket or not object_name or ":" in bucket:
raise ValueError(f"Invalid s3:// URL format: {s3_url}")
return bucket, object_name
raise ValueError(f"Invalid s3:// URL format: {s3_url}")
@@ -351,4 +357,4 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
return bucket, object_name
raise ValueError(f"Invalid path format: {s3_url}")
- raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...")
\ No newline at end of file
+ raise ValueError(f"Unrecognized S3 URL format: {s3_url[:50]}...")
diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py
index 5817fbe27..20340f2ea 100644
--- a/test/backend/agents/test_create_agent_info.py
+++ b/test/backend/agents/test_create_agent_info.py
@@ -30,6 +30,21 @@ class ValidationError(Exception):
pass
+class MCPConnectionError(Exception):
+ """Mock MCPConnectionError for testing."""
+ pass
+
+
+class NotFoundException(Exception):
+ """Mock NotFoundException for testing."""
+ pass
+
+
+class ToolExecutionException(Exception):
+ """Mock ToolExecutionException for testing."""
+ pass
+
+
consts_model_module = types.ModuleType("consts.model")
consts_model_module.HistoryItem = HistoryItem
sys.modules["consts.model"] = consts_model_module
@@ -37,6 +52,9 @@ class ValidationError(Exception):
# Mock consts.exceptions module with ValidationError
consts_exceptions_module = types.ModuleType("consts.exceptions")
consts_exceptions_module.ValidationError = ValidationError
+consts_exceptions_module.MCPConnectionError = MCPConnectionError
+consts_exceptions_module.NotFoundException = NotFoundException
+consts_exceptions_module.ToolExecutionException = ToolExecutionException
sys.modules["consts.exceptions"] = consts_exceptions_module
# Also add model and exceptions to consts module attributes
@@ -165,7 +183,9 @@ def _create_stub_module(name: str, **attrs):
services_module = _create_stub_module("services")
sys.modules['services'] = services_module
sys.modules['services.image_service'] = _create_stub_module(
- "services.image_service", get_vlm_model=MagicMock(return_value="stub_vlm")
+ "services.image_service",
+ get_vlm_model=MagicMock(return_value="stub_vlm"),
+ get_video_understanding_model=MagicMock(return_value="stub_video_vlm"),
)
sys.modules['services.memory_config_service'] = MagicMock()
# Extend services hierarchy with additional stubs
@@ -250,6 +270,7 @@ def _create_stub_module(name: str, **attrs):
_extract_url_from_card,
_build_external_agent_config,
_get_external_a2a_agents,
+ _build_internal_s3_url,
_format_minio_files_for_content,
_convert_history_with_minio_files,
)
@@ -727,6 +748,48 @@ async def test_create_tool_config_list_with_analyze_image_tool(self):
assert "validate_url_access" in mock_tool_instance.metadata
assert callable(mock_tool_instance.metadata["validate_url_access"])
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "class_name,tool_name",
+ [
+ ("AnalyzeAudioTool", "analyze_audio"),
+ ("AnalyzeVideoTool", "analyze_video"),
+ ],
+ )
+ async def test_create_tool_config_list_with_audio_video_tools(self, class_name, tool_name):
+ """Ensure audio/video tools receive video understanding model metadata."""
+ mock_tool_instance = MagicMock()
+ mock_tool_instance.class_name = class_name
+ mock_tool_config.return_value = mock_tool_instance
+
+ with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \
+ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \
+ patch('backend.agents.create_agent_info.get_video_understanding_model') as mock_get_video_model, \
+ patch('backend.agents.create_agent_info.minio_client', new_callable=MagicMock):
+
+ mock_search_tools.return_value = [
+ {
+ "class_name": class_name,
+ "name": tool_name,
+ "description": "Analyze media tool",
+ "inputs": "string",
+ "output_type": "string",
+ "params": [{"name": "prompt", "default": "describe"}],
+ "source": "local",
+ "usage": None
+ }
+ ]
+ mock_get_video_model.return_value = "mock_video_model"
+
+ result = await create_tool_config_list("agent_1", "tenant_1", "user_1")
+
+ assert len(result) == 1
+ assert result[0] is mock_tool_instance
+ mock_get_video_model.assert_called_once_with(tenant_id="tenant_1")
+ assert mock_tool_instance.metadata["vlm_model"] == "mock_video_model"
+ assert "storage_client" in mock_tool_instance.metadata
+ assert callable(mock_tool_instance.metadata["validate_url_access"])
+
@pytest.mark.asyncio
async def test_create_tool_config_list_with_analyze_text_file_tool(self):
"""Ensure AnalyzeTextFileTool receives text-specific metadata."""
@@ -3297,6 +3360,26 @@ async def test_create_agent_run_info_is_need_auth_true_includes_token(self):
class TestJoinMinioFileDescriptionToQuery:
"""Tests for the join_minio_file_description_to_query function"""
+ def test_build_internal_s3_url_prefers_object_name(self):
+ file = {
+ "object_name": "attachments/user/image.png",
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+
+ result = _build_internal_s3_url(file)
+
+ assert result.endswith("/attachments/user/image.png")
+ assert result.startswith("s3://")
+
+ def test_build_internal_s3_url_rejects_blob_preview_url(self):
+ file = {
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+
+ assert _build_internal_s3_url(file) == ""
+
@pytest.mark.asyncio
async def test_join_minio_file_description_to_query_with_files(self):
"""Test case with file descriptions"""
@@ -3345,6 +3428,40 @@ async def test_join_minio_file_description_to_query_no_descriptions(self):
assert result == "test query"
+ @pytest.mark.asyncio
+ async def test_join_minio_file_description_to_query_prefers_object_name_over_blob_url(self):
+ """Uploaded images should be exposed to internal tools through MinIO, not browser blob URLs."""
+ minio_files = [
+ {
+ "object_name": "attachments/user/image.png",
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+ ]
+ query = "describe the image"
+
+ result = await join_minio_file_description_to_query(minio_files, query)
+
+ assert "blob:http" not in result
+ assert "File name: image.png" in result
+ assert "attachments/user/image.png" in result
+ assert "S3 URL: s3://" in result
+
+ @pytest.mark.asyncio
+ async def test_join_minio_file_description_to_query_skips_blob_only_file(self):
+ """Browser-only preview URLs cannot be used by internal tools."""
+ minio_files = [
+ {
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+ ]
+ query = "describe the image"
+
+ result = await join_minio_file_description_to_query(minio_files, query)
+
+ assert result == query
+
@pytest.mark.asyncio
async def test_join_minio_file_description_to_query_deduplication_current(self):
"""Test that duplicate files in current message are de-duplicated by URL"""
@@ -4455,6 +4572,21 @@ def test_format_minio_files_for_content_single_file_without_presigned_url(self):
result = _format_minio_files_for_content(minio_files)
assert result == "\n[Attached files]:\n - file.txt: s3:/bucket/file.txt"
+ def test_format_minio_files_for_content_uses_object_name_for_blob_url(self):
+ """Use uploaded object_name instead of browser-only blob preview URL."""
+ minio_files = [
+ {
+ "object_name": "attachments/user/image.png",
+ "url": "blob:http://localhost:3000/preview",
+ "name": "image.png",
+ }
+ ]
+
+ result = _format_minio_files_for_content(minio_files)
+
+ assert "blob:http" not in result
+ assert "attachments/user/image.png" in result
+
def test_format_minio_files_for_content_multiple_files(self):
"""Test case for multiple files"""
minio_files = [
diff --git a/test/backend/services/providers/test_silicon_provider.py b/test/backend/services/providers/test_silicon_provider.py
index b947040c3..c596643b2 100644
--- a/test/backend/services/providers/test_silicon_provider.py
+++ b/test/backend/services/providers/test_silicon_provider.py
@@ -66,7 +66,13 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
- {"id": "gpt-4v", "name": "GPT-4 Vision"},
+ {"id": "deepseek-ai/DeepSeek-R1", "name": "DeepSeek R1"},
+ {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "name": "Qwen2.5 VL"},
+ {"id": "OpenGVLab/InternVL2-26B", "name": "InternVL2 26B"},
+ {"id": "Pro/moonshotai/Kimi-K2.6", "name": "Kimi K2.6"},
+ {"id": "Pro/moonshotai/Kimi-K2.5", "name": "Kimi K2.5"},
+ {"id": "Qwen/Qwen3.6-27B", "name": "Qwen3.6 27B"},
+ {"id": "Qwen/Qwen3.6-35B-A3B", "name": "Qwen3.6 35B A3B"},
]
}
mock_response.raise_for_status = MagicMock()
@@ -95,10 +101,66 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
result = await provider.get_models(provider_config)
- assert len(result) == 1
- assert result[0]["id"] == "gpt-4v"
- assert result[0]["model_type"] == "vlm"
- assert result[0]["model_tag"] == "chat"
+ assert [model["id"] for model in result] == [
+ "Qwen/Qwen2.5-VL-72B-Instruct",
+ "OpenGVLab/InternVL2-26B",
+ "Pro/moonshotai/Kimi-K2.6",
+ "Pro/moonshotai/Kimi-K2.5",
+ "Qwen/Qwen3.6-27B",
+ "Qwen/Qwen3.6-35B-A3B",
+ ]
+ assert all(model["model_type"] == "vlm" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+
+ @pytest.mark.asyncio
+ async def test_get_models_vlm3_only_returns_omni_models(self, mocker: MockFixture):
+ """Test that SiliconFlow video understanding models are restricted to Omni models."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "data": [
+ {"id": "Qwen/Qwen3-VL-32B-Instruct", "name": "Qwen3 VL"},
+ {"id": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "name": "Qwen3 Omni Instruct"},
+ {"id": "Qwen/Qwen3-Omni-30B-A3B-Thinking", "name": "Qwen3 Omni Thinking"},
+ {"id": "Qwen/Qwen3-Omni-30B-A3B-Captioner", "name": "Qwen3 Omni Captioner"},
+ {"id": "zai-org/GLM-4.5V", "name": "GLM 4.5V"},
+ ]
+ }
+ mock_response.raise_for_status = MagicMock()
+
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_response
+
+ mock_cm = MagicMock()
+ mock_cm.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_cm.__aexit__ = AsyncMock(return_value=None)
+
+ mocker.patch(
+ "backend.services.providers.silicon_provider.httpx.AsyncClient",
+ return_value=mock_cm
+ )
+ mocker.patch(
+ "backend.services.providers.silicon_provider.SILICON_GET_URL",
+ "https://api.siliconflow.com/v1/models"
+ )
+
+ provider = SiliconModelProvider()
+ provider_config = {
+ "model_type": "vlm3",
+ "api_key": "test-api-key"
+ }
+
+ result = await provider.get_models(provider_config)
+
+ assert [model["id"] for model in result] == [
+ "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "Qwen/Qwen3-Omni-30B-A3B-Thinking",
+ "Qwen/Qwen3-Omni-30B-A3B-Captioner",
+ ]
+ assert all(model["model_type"] == "vlm3" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+ call_args = mock_client.get.call_args
+ assert "sub_type=chat" in call_args[0][0]
@pytest.mark.asyncio
async def test_get_models_embedding_success(self, mocker: MockFixture):
diff --git a/test/backend/services/test_image_service.py b/test/backend/services/test_image_service.py
index 1de8d49fd..34f24568c 100644
--- a/test/backend/services/test_image_service.py
+++ b/test/backend/services/test_image_service.py
@@ -13,10 +13,17 @@
helpers_env = bootstrap_test_env()
helpers_env["mock_const"].DATA_PROCESS_SERVICE = "http://mock-data-process-service"
-helpers_env["mock_const"].MODEL_CONFIG_MAPPING = {"vlm": "vlm_model_config"}
+helpers_env["mock_const"].MODEL_CONFIG_MAPPING = {
+ "vlm": "vlm_model_config",
+ "vlm3": "video_model_config",
+}
mock_const = helpers_env["mock_const"]
-from services.image_service import get_vlm_model, proxy_image_impl
+from services.image_service import get_image_understanding_model, get_video_understanding_model, get_vlm_model, proxy_image_impl
+
+image_service_module = sys.modules[get_vlm_model.__module__]
+if "services" in sys.modules:
+ setattr(sys.modules["services"], "image_service", image_service_module)
# Sample test data
test_url = "https://example.com/image.jpg"
@@ -50,7 +57,7 @@ async def test_proxy_image_impl_success():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -85,7 +92,7 @@ async def test_proxy_image_impl_remote_error():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -118,7 +125,7 @@ async def test_proxy_image_impl_500_error():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -146,7 +153,7 @@ async def test_proxy_image_impl_connection_exception():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function - should raise the exception
@@ -178,7 +185,7 @@ async def test_proxy_image_impl_with_special_chars():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -213,7 +220,7 @@ async def test_proxy_image_impl_json_parse_error():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function - should raise the exception
@@ -253,7 +260,7 @@ async def test_proxy_image_impl_different_status_codes():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function
@@ -289,7 +296,7 @@ async def test_proxy_image_impl_url_encoding():
mock_client_session.__aenter__.return_value = mock_session
# Patch the ClientSession
- with patch('services.image_service.aiohttp.ClientSession') as mock_session_class:
+ with patch.object(image_service_module.aiohttp, 'ClientSession') as mock_session_class:
mock_session_class.return_value = mock_client_session
# Test the function with encoded URL
@@ -305,10 +312,10 @@ async def test_proxy_image_impl_url_encoding():
assert f"url={encoded_url}" in called_url
-@patch('services.image_service.OpenAIVLModel')
-@patch('services.image_service.MessageObserver')
-@patch('services.image_service.get_model_name_from_config')
-@patch('services.image_service.tenant_config_manager')
+@patch.object(image_service_module, 'OpenAIVLModel')
+@patch.object(image_service_module, 'MessageObserver')
+@patch.object(image_service_module, 'get_model_name_from_config')
+@patch.object(image_service_module, 'tenant_config_manager')
def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model):
"""Ensure get_vlm_model builds OpenAIVLModel with tenant config."""
mock_config = {
@@ -324,7 +331,7 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name,
result = get_vlm_model("tenant-1")
mock_tenant_config_manager.get_model_config.assert_called_once_with(
- key=mock_const.MODEL_CONFIG_MAPPING["vlm"],
+ key="vlm_model_config",
tenant_id="tenant-1"
)
mock_message_observer.assert_called_once_with()
@@ -342,10 +349,10 @@ def test_get_vlm_model_success(mock_tenant_config_manager, mock_get_model_name,
assert result == mock_model_instance
-@patch('services.image_service.OpenAIVLModel')
-@patch('services.image_service.MessageObserver')
-@patch('services.image_service.get_model_name_from_config')
-@patch('services.image_service.tenant_config_manager')
+@patch.object(image_service_module, 'OpenAIVLModel')
+@patch.object(image_service_module, 'MessageObserver')
+@patch.object(image_service_module, 'get_model_name_from_config')
+@patch.object(image_service_module, 'tenant_config_manager')
def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model):
"""Return None when tenant config is None."""
mock_tenant_config_manager.get_model_config.return_value = None
@@ -359,3 +366,40 @@ def test_get_vlm_model_with_none_config(mock_tenant_config_manager, mock_get_mod
# OpenAIVLModel should not be called when config is None
mock_openai_vl_model.assert_not_called()
assert result is None
+
+
+@patch.object(image_service_module, 'get_vlm_model')
+def test_get_image_understanding_model_uses_first_multimodal_slot(mock_get_vlm_model):
+ """Ensure the image understanding alias keeps using the first multimodal slot."""
+ mock_get_vlm_model.return_value = "image-understanding-model"
+
+ result = get_image_understanding_model("tenant-1")
+
+ mock_get_vlm_model.assert_called_once_with(tenant_id="tenant-1")
+ assert result == "image-understanding-model"
+
+
+@patch.object(image_service_module, 'OpenAIVLModel')
+@patch.object(image_service_module, 'MessageObserver')
+@patch.object(image_service_module, 'get_model_name_from_config')
+@patch.object(image_service_module, 'tenant_config_manager')
+def test_get_video_understanding_model_success(mock_tenant_config_manager, mock_get_model_name, mock_message_observer, mock_openai_vl_model):
+ """Ensure video understanding tools use the third multimodal model slot."""
+ mock_config = {
+ "base_url": "https://mock-video-api",
+ "api_key": "secret",
+ "model_name": "video-model"
+ }
+ mock_tenant_config_manager.get_model_config.return_value = mock_config
+ mock_get_model_name.return_value = "video-model"
+ mock_model_instance = MagicMock()
+ mock_openai_vl_model.return_value = mock_model_instance
+
+ result = get_video_understanding_model("tenant-1")
+
+ mock_tenant_config_manager.get_model_config.assert_called_once_with(
+ key="video_model_config",
+ tenant_id="tenant-1"
+ )
+ mock_openai_vl_model.assert_called_once()
+ assert result == mock_model_instance
diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py
index 6e504e90a..bd56cce40 100644
--- a/test/backend/services/test_model_management_service.py
+++ b/test/backend/services/test_model_management_service.py
@@ -660,17 +660,11 @@ async def test_batch_create_models_for_tenant_flow():
existing = [
{"model_id": "del-id", "model_repo": "silicon", "model_name": "delete"},
- {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep"},
+ {"model_id": "keep-id", "model_repo": "silicon", "model_name": "keep", "max_tokens": 1024},
]
- def get_by_display(display_name, tenant_id):
- if display_name == "silicon/keep":
- return {"model_id": "keep-id", "max_tokens": 1024}
- return None
-
with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing) as mock_get_existing, \
mock.patch.object(svc, "delete_model_record") as mock_delete, \
- mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \
mock.patch.object(svc, "update_model_record") as mock_update, \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \
mock.patch.object(svc, "create_model_record") as mock_create:
@@ -679,13 +673,35 @@ def get_by_display(display_name, tenant_id):
mock_get_existing.assert_called_once_with("t1", "silicon", "llm")
mock_delete.assert_called_once_with("del-id", "u1", "t1")
- mock_get_by_display.assert_any_call("silicon/keep", "t1")
mock_update.assert_called_once_with(
"keep-id", {"max_tokens": 4096}, "u1")
mock_prep.assert_awaited()
mock_create.assert_called_once()
+@pytest.mark.asyncio
+async def test_batch_create_models_uses_requested_type_for_each_model():
+ svc = import_svc()
+
+ batch_payload = {
+ "provider": "silicon",
+ "type": "vlm",
+ "models": [
+ {"id": "Qwen/Qwen2.5-VL-72B-Instruct", "model_type": "llm", "max_tokens": 4096},
+ ],
+ "api_key": "k",
+ }
+
+ with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \
+ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"prepared": True})) as mock_prep, \
+ mock.patch.object(svc, "create_model_record"):
+
+ await svc.batch_create_models_for_tenant("u1", "t1", batch_payload)
+
+ prepared_model = mock_prep.call_args.kwargs["model"]
+ assert prepared_model["model_type"] == "vlm"
+
+
@pytest.mark.asyncio
async def test_batch_create_models_max_tokens_update():
"""Test batch_create_models updates max_tokens when display_name exists and max_tokens changed (covers lines 160->173, 168->171)"""
@@ -702,22 +718,16 @@ async def test_batch_create_models_max_tokens_update():
"api_key": "k",
}
- def get_by_display(display_name, tenant_id):
- if display_name == "silicon/model1":
- # Different from new value
- return {"model_id": "id1", "max_tokens": 4096}
- elif display_name == "silicon/model2":
- return {"model_id": "id2", "max_tokens": 4096} # Same as new value
- elif display_name == "silicon/model3":
- # Existing has value, new is None
- return {"model_id": "id3", "max_tokens": 2048}
- return None
+ existing = [
+ {"model_id": "id1", "model_repo": "silicon", "model_name": "model1", "max_tokens": 4096},
+ {"model_id": "id2", "model_repo": "silicon", "model_name": "model2", "max_tokens": 4096},
+ {"model_id": "id3", "model_repo": "silicon", "model_name": "model3", "max_tokens": 2048},
+ ]
- with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[]), \
+ with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing), \
mock.patch.object(svc, "delete_model_record"), \
mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \
mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \
- mock.patch.object(svc, "get_model_by_display_name", side_effect=get_by_display) as mock_get_by_display, \
mock.patch.object(svc, "update_model_record") as mock_update, \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \
mock.patch.object(svc, "create_model_record", return_value=True):
diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py
index 3cbdcee2b..44bcb7681 100644
--- a/test/backend/services/test_tool_configuration_service.py
+++ b/test/backend/services/test_tool_configuration_service.py
@@ -15,6 +15,54 @@
minio_client_mock = MagicMock()
sys.modules['boto3'] = boto3_mock
+fastmcp_mock = types.ModuleType('fastmcp')
+fastmcp_mock.__path__ = []
+
+
+class MockFastMcpClient:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+ def is_connected(self):
+ return True
+
+ async def call_tool(self, *args, **kwargs):
+ return MagicMock()
+
+
+class MockSSETransport:
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+class MockStreamableHttpTransport:
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+fastmcp_mock.Client = MockFastMcpClient
+fastmcp_client_mock = types.ModuleType('fastmcp.client')
+fastmcp_client_mock.__path__ = []
+fastmcp_transports_mock = types.ModuleType('fastmcp.client.transports')
+fastmcp_transports_mock.SSETransport = MockSSETransport
+fastmcp_transports_mock.StreamableHttpTransport = MockStreamableHttpTransport
+sys.modules['fastmcp'] = fastmcp_mock
+sys.modules['fastmcp.client'] = fastmcp_client_mock
+sys.modules['fastmcp.client.transports'] = fastmcp_transports_mock
+
+mcpadapt_mock = types.ModuleType('mcpadapt')
+mcpadapt_mock.__path__ = []
+mcpadapt_smolagents_adapter_mock = types.ModuleType('mcpadapt.smolagents_adapter')
+mcpadapt_smolagents_adapter_mock._sanitize_function_name = lambda name: name
+sys.modules['mcpadapt'] = mcpadapt_mock
+sys.modules['mcpadapt.smolagents_adapter'] = mcpadapt_smolagents_adapter_mock
+
# Patch smolagents and its sub-modules before importing consts.model to avoid ImportError
mock_smolagents = MagicMock()
sys.modules['smolagents'] = mock_smolagents
@@ -323,7 +371,7 @@ def validate(self):
'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(),
'ElasticSearchService': MagicMock()},
'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()},
- 'image_service': {'get_vlm_model': MagicMock()}
+ 'image_service': {'get_vlm_model': MagicMock(), 'get_video_understanding_model': MagicMock()}
}
for service_name, attrs in services_modules.items():
service_module = types.ModuleType(f'services.{service_name}')
@@ -354,6 +402,7 @@ def validate(self):
patch('services.tenant_config_service.build_knowledge_name_mapping',
MagicMock()).start()
patch('services.image_service.get_vlm_model', MagicMock()).start()
+patch('services.image_service.get_video_understanding_model', MagicMock()).start()
patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start()
patch('backend.services.tool_configuration_service.get_embedding_model_by_index_name', MagicMock()).start()
@@ -2676,6 +2725,63 @@ def test_validate_local_tool_analyze_image_missing_user(self, mock_get_class):
)
+class TestValidateLocalToolAnalyzeAudioVideo:
+ """Test cases for _validate_local_tool with analyze_audio/analyze_video tools."""
+
+ @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"])
+ @patch('backend.services.tool_configuration_service.minio_client')
+ @patch('backend.services.tool_configuration_service.get_video_understanding_model')
+ @patch('backend.services.tool_configuration_service._get_tool_class_by_name')
+ @patch('backend.services.tool_configuration_service.inspect.signature')
+ def test_validate_local_tool_analyze_audio_video_success(
+ self, mock_signature, mock_get_class, mock_get_video_model, mock_minio_client, tool_name):
+ mock_tool_class = Mock()
+ mock_tool_instance = Mock()
+ mock_tool_instance.forward.return_value = f"{tool_name} result"
+ mock_tool_class.return_value = mock_tool_instance
+ mock_get_class.return_value = mock_tool_class
+ mock_get_video_model.return_value = "mock_video_model"
+
+ mock_sig = Mock()
+ mock_sig.parameters = {}
+ mock_signature.return_value = mock_sig
+
+ from backend.services.tool_configuration_service import _validate_local_tool
+
+ result = _validate_local_tool(
+ tool_name,
+ {"media": "bytes"},
+ {"prompt": "describe"},
+ "tenant1",
+ "user1"
+ )
+
+ assert result == f"{tool_name} result"
+ mock_get_video_model.assert_called_once_with(tenant_id="tenant1")
+ call_kwargs = mock_tool_class.call_args.kwargs
+ assert call_kwargs["vlm_model"] == "mock_video_model"
+ assert "storage_client" in call_kwargs
+ assert callable(call_kwargs["validate_url_access"])
+ mock_tool_instance.forward.assert_called_once_with(media="bytes")
+
+ @pytest.mark.parametrize("tool_name", ["analyze_audio", "analyze_video"])
+ @patch('backend.services.tool_configuration_service._get_tool_class_by_name')
+ def test_validate_local_tool_analyze_audio_video_missing_tenant(self, mock_get_class, tool_name):
+ mock_get_class.return_value = Mock()
+
+ from backend.services.tool_configuration_service import _validate_local_tool
+
+ with pytest.raises(ToolExecutionException,
+ match=f"Tenant ID and User ID are required for {tool_name} validation"):
+ _validate_local_tool(
+ tool_name,
+ {"media": "bytes"},
+ {"prompt": "describe"},
+ None,
+ "user1"
+ )
+
+
class TestValidateLocalToolDatamateSearchTool:
"""Test cases for _validate_local_tool function with datamate_search_tool"""
diff --git a/test/common/test_mocks.py b/test/common/test_mocks.py
index c87b52859..c57941780 100644
--- a/test/common/test_mocks.py
+++ b/test/common/test_mocks.py
@@ -112,6 +112,8 @@ def setup_common_mocks():
"multiEmbedding": "MULTI_EMBEDDING_ID",
"rerank": "RERANK_ID",
"vlm": "VLM_ID",
+ "vlm2": "VLM2_ID",
+ "vlm3": "VLM3_ID",
"stt": "STT_ID",
"tts": "TTS_ID"
}
diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py
index 9853b9eca..2d588b62d 100644
--- a/test/sdk/core/agents/test_nexent_agent.py
+++ b/test/sdk/core/agents/test_nexent_agent.py
@@ -2429,6 +2429,52 @@ def test_create_local_tool_analyze_image(self, nexent_agent_instance):
assert call_kwargs["param1"] == "value1"
assert result == mock_tool_instance
+ @pytest.mark.parametrize(
+ "class_name,tool_name",
+ [
+ ("AnalyzeAudioTool", "analyze_audio"),
+ ("AnalyzeVideoTool", "analyze_video"),
+ ],
+ )
+ def test_create_local_tool_analyze_audio_video(self, nexent_agent_instance, class_name, tool_name):
+ """Test successful audio/video analysis tool creation."""
+ mock_tool_class = MagicMock()
+ mock_tool_instance = MagicMock()
+ mock_tool_class.return_value = mock_tool_instance
+
+ tool_config = ToolConfig(
+ class_name=class_name,
+ name=tool_name,
+ description="desc",
+ inputs="{}",
+ output_type="string",
+ params={"param1": "value1"},
+ source="local",
+ metadata={
+ "vlm_model": ["video-understanding-model"],
+ "storage_client": "storage"
+ }
+ )
+
+ original_value = nexent_agent.__dict__.get(class_name)
+ nexent_agent.__dict__[class_name] = mock_tool_class
+
+ try:
+ result = nexent_agent_instance.create_local_tool(tool_config)
+ finally:
+ if original_value is not None:
+ nexent_agent.__dict__[class_name] = original_value
+ elif class_name in nexent_agent.__dict__:
+ del nexent_agent.__dict__[class_name]
+
+ mock_tool_class.assert_called_once()
+ call_kwargs = mock_tool_class.call_args[1]
+ assert call_kwargs["observer"] == nexent_agent_instance.observer
+ assert call_kwargs["vlm_model"] == ["video-understanding-model"]
+ assert call_kwargs["storage_client"] == "storage"
+ assert call_kwargs["param1"] == "value1"
+ assert result == mock_tool_instance
+
def test_create_local_tool_analyze_text_file_with_validate_url_access_none(self, nexent_agent_instance):
"""Test AnalyzeTextFileTool creation with validate_url_access not in metadata (None)."""
mock_tool_class = MagicMock()
diff --git a/test/sdk/core/models/test_openai_vlm.py b/test/sdk/core/models/test_openai_vlm.py
index f1db49380..4f7104290 100644
--- a/test/sdk/core/models/test_openai_vlm.py
+++ b/test/sdk/core/models/test_openai_vlm.py
@@ -62,7 +62,13 @@ def vl_model_instance():
"""Return an OpenAIVLModel instance with minimal viable attributes for tests."""
observer = MagicMock()
- model = ImportedOpenAIVLModel(observer=observer, ssl_verify=True)
+ model = ImportedOpenAIVLModel(
+ observer=observer,
+ model_id="dummy-model",
+ api_key="dummy-key",
+ api_base="https://example.test",
+ ssl_verify=True,
+ )
# Inject dummy attributes required by the method under test
model.model_id = "dummy-model"
@@ -321,3 +327,55 @@ def test_analyze_image_calls_prepare_image_message(vl_model_instance, tmp_path):
# Verify prepare_image_message was called with correct arguments
mock_prepare.assert_called_once_with(str(test_image), custom_prompt)
+
+
+def test_prepare_media_message_audio(vl_model_instance):
+ audio_stream = MagicMock()
+ audio_stream.read.return_value = b"audio bytes"
+
+ messages = vl_model_instance.prepare_media_message(
+ audio_stream,
+ media_type="audio",
+ content_type="audio/mpeg",
+ system_prompt="Listen carefully",
+ )
+
+ assert messages[0]["content"][0]["type"] == "audio_url"
+ assert messages[0]["content"][0]["audio_url"]["url"].startswith("data:audio/mpeg;base64,")
+ assert messages[0]["content"][1] == {"type": "text", "text": "Listen carefully"}
+
+
+def test_prepare_media_message_video(vl_model_instance):
+ video_stream = MagicMock()
+ video_stream.read.return_value = b"video bytes"
+
+ messages = vl_model_instance.prepare_media_message(
+ video_stream,
+ media_type="video",
+ content_type="video/mp4",
+ system_prompt="Watch carefully",
+ )
+
+ assert messages[0]["content"][0]["type"] == "video_url"
+ assert messages[0]["content"][0]["video_url"]["url"].startswith("data:video/mp4;base64,")
+ assert messages[0]["content"][0]["video_url"]["max_frames"] == 16
+ assert messages[0]["content"][0]["video_url"]["fps"] == 1
+ assert messages[0]["content"][1] == {"type": "text", "text": "Watch carefully"}
+
+
+def test_analyze_audio_calls_prepare_media_message(vl_model_instance):
+ with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare:
+ vl_model_instance.__call__ = MagicMock(return_value=MagicMock())
+
+ vl_model_instance.analyze_audio("audio.mp3", system_prompt="Analyze", content_type="audio/mpeg")
+
+ mock_prepare.assert_called_once_with("audio.mp3", "audio", "audio/mpeg", "Analyze")
+
+
+def test_analyze_video_calls_prepare_media_message(vl_model_instance):
+ with patch.object(vl_model_instance, "prepare_media_message", return_value=[{"role": "user", "content": "test"}]) as mock_prepare:
+ vl_model_instance.__call__ = MagicMock(return_value=MagicMock())
+
+ vl_model_instance.analyze_video("video.mp4", system_prompt="Analyze", content_type="video/mp4")
+
+ mock_prepare.assert_called_once_with("video.mp4", "video", "video/mp4", "Analyze")
diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py
new file mode 100644
index 000000000..94401b61d
--- /dev/null
+++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py
@@ -0,0 +1,119 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from sdk.nexent.core.tools import analyze_audio_tool, analyze_video_tool
+from sdk.nexent.core.tools.analyze_audio_tool import AnalyzeAudioTool
+from sdk.nexent.core.tools.analyze_video_tool import AnalyzeVideoTool
+from sdk.nexent.core.utils.observer import MessageObserver, ProcessType
+
+
+@pytest.fixture
+def mock_storage_client():
+ class DummyStorage:
+ pass
+
+ return DummyStorage()
+
+
+@pytest.fixture
+def mock_vlm_model():
+ return MagicMock()
+
+
+@pytest.fixture
+def observer_en():
+ observer = MagicMock(spec=MessageObserver)
+ observer.lang = "en"
+ return observer
+
+
+def test_analyze_audio_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch):
+ calls = []
+
+ def _fake_get_prompt(template_type, language=None, **_):
+ calls.append((template_type, language))
+ return {"system_prompt": "Analyze audio for {{ query }}"}
+
+ monkeypatch.setattr(analyze_audio_tool, "get_prompt_template", _fake_get_prompt)
+ mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result")
+ tool = AnalyzeAudioTool(
+ observer=observer_en,
+ vlm_model=mock_vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ result = tool._forward_impl([b"ID3audio-bytes"], "what happened?")
+
+ assert result == ["audio result"]
+ assert calls == [("analyze_audio", "en")]
+ mock_vlm_model.analyze_audio.assert_called_once()
+ call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs
+ assert hasattr(call_kwargs["audio_input"], "read")
+ assert call_kwargs["content_type"].startswith("audio/")
+ observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...")
+
+
+def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client):
+ vlm_model = SimpleNamespace(
+ model_id="Qwen/Qwen3-VL-32B-Instruct",
+ client_kwargs={"base_url": "https://api.siliconflow.cn/v1"},
+ )
+ tool = AnalyzeAudioTool(
+ observer=observer_en,
+ vlm_model=vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ tool._forward_impl([b"ID3audio-bytes"], "what happened?")
+
+ assert "Please choose a Qwen3-Omni model" in str(exc_info.value)
+
+
+def test_analyze_video_uses_video_understanding_model(observer_en, mock_vlm_model, mock_storage_client, monkeypatch):
+ calls = []
+
+ def _fake_get_prompt(template_type, language=None, **_):
+ calls.append((template_type, language))
+ return {"system_prompt": "Analyze video for {{ query }}"}
+
+ monkeypatch.setattr(analyze_video_tool, "get_prompt_template", _fake_get_prompt)
+ mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result")
+ tool = AnalyzeVideoTool(
+ observer=observer_en,
+ vlm_model=mock_vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?")
+
+ assert result == ["video result"]
+ assert calls == [("analyze_video", "en")]
+ mock_vlm_model.analyze_video.assert_called_once()
+ call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs
+ assert hasattr(call_kwargs["video_input"], "read")
+ assert call_kwargs["content_type"].startswith("video/")
+ observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...")
+
+
+@pytest.mark.parametrize(
+ "tool_class,input_name,error_text",
+ [
+ (AnalyzeAudioTool, "audio_urls_list", "Video understanding model is not configured"),
+ (AnalyzeVideoTool, "video_urls_list", "Video understanding model is not configured"),
+ ],
+)
+def test_analyze_audio_video_require_video_understanding_model(
+ tool_class, input_name, error_text, observer_en, mock_storage_client):
+ tool = tool_class(
+ observer=observer_en,
+ vlm_model=None,
+ storage_client=mock_storage_client,
+ )
+
+ with pytest.raises(Exception) as exc_info:
+ tool._forward_impl(**{input_name: [b"media"], "query": "question"})
+
+ assert error_text in str(exc_info.value)
diff --git a/test/sdk/core/tools/test_analyze_image_tool.py b/test/sdk/core/tools/test_analyze_image_tool.py
index a8598a8ad..63be0ac54 100644
--- a/test/sdk/core/tools/test_analyze_image_tool.py
+++ b/test/sdk/core/tools/test_analyze_image_tool.py
@@ -136,7 +136,7 @@ def test_forward_impl_vlm_model_none(self, observer_en, mock_storage_client):
with pytest.raises(Exception) as exc_info:
tool._forward_impl([b"img"], "question")
- assert "Vision Language Model (VLM) is not configured" in str(
+ assert "Image understanding model is not configured" in str(
exc_info.value)
def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_client):
@@ -150,7 +150,7 @@ def test_forward_impl_vlm_model_none_chinese(self, observer_zh, mock_storage_cli
with pytest.raises(Exception) as exc_info:
tool._forward_impl([b"img"], "问题")
- assert "视觉语言模型(VLM)未配置" in str(exc_info.value)
+ assert "图片理解模型未配置" in str(exc_info.value)
def test_forward_impl_observer_none_uses_english(self, mock_vlm_model, mock_storage_client):
"""Test that English is used when observer is None."""
@@ -353,7 +353,7 @@ def test_observer_add_message_not_called_when_none(self, mock_vlm_model, mock_st
def test_tool_name_and_description(self, tool):
"""Test that tool name and description are set correctly."""
assert tool.name == "analyze_image"
- assert "visual language model" in tool.description.lower()
+ assert "image understanding model" in tool.description.lower()
assert "image" in tool.description.lower()
def test_tool_inputs_schema(self, tool):
diff --git a/test/sdk/core/utils/test_prompt_template_utils.py b/test/sdk/core/utils/test_prompt_template_utils.py
index c0a3ad634..a50929b8d 100644
--- a/test/sdk/core/utils/test_prompt_template_utils.py
+++ b/test/sdk/core/utils/test_prompt_template_utils.py
@@ -61,6 +61,28 @@ def test_get_prompt_template_analyze_image_en(self, mock_yaml_load, mock_file):
# Verify result
assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"}
+ @pytest.mark.parametrize(
+ "template_type,language,expected_file",
+ [
+ ("analyze_audio", "en", "prompts/analyze_audio_en.yaml"),
+ ("analyze_audio", "zh", "prompts/analyze_audio_zh.yaml"),
+ ("analyze_video", "en", "prompts/analyze_video_en.yaml"),
+ ("analyze_video", "zh", "prompts/analyze_video_zh.yaml"),
+ ],
+ )
+ @patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"\nuser_prompt: "User prompt"')
+ @patch('yaml.safe_load')
+ def test_get_prompt_template_analyze_audio_video(
+ self, mock_yaml_load, mock_file, template_type, language, expected_file):
+ """Test get_prompt_template for audio/video templates."""
+ mock_yaml_load.return_value = {"system_prompt": "Test prompt", "user_prompt": "User prompt"}
+
+ result = get_prompt_template(template_type=template_type, language=language)
+
+ call_args = mock_file.call_args[0]
+ assert expected_file in call_args[0].replace('\\', '/')
+ assert result == {"system_prompt": "Test prompt", "user_prompt": "User prompt"}
+
@patch('builtins.open', new_callable=mock_open, read_data='system_prompt: "Test prompt"')
@patch('yaml.safe_load')
@patch('sdk.nexent.core.utils.prompt_template_utils.LANGUAGE', {'ZH': 'zh', 'EN': 'en'})
@@ -174,4 +196,4 @@ def test_get_prompt_template_path_resolution(self, mock_yaml_load, mock_file):
assert mock_file.called
call_args = mock_file.call_args[0]
# Path should be absolute or contain the expected template file
- assert 'analyze_image_en.yaml' in call_args[0]
\ No newline at end of file
+ assert 'analyze_image_en.yaml' in call_args[0]
diff --git a/test/sdk/multi_modal/test_load_save_object.py b/test/sdk/multi_modal/test_load_save_object.py
index 92425791c..1670e6a9d 100644
--- a/test/sdk/multi_modal/test_load_save_object.py
+++ b/test/sdk/multi_modal/test_load_save_object.py
@@ -26,6 +26,23 @@ def test_get_client_requires_initialized_storage():
manager._get_client()
+def test_s3_single_slash_url_supported():
+ assert lso.is_url("s3:/bucket/path/to/image.png") == "s3"
+ assert lso.parse_s3_url("s3:/bucket/path/to/image.png") == (
+ "bucket",
+ "path/to/image.png",
+ )
+
+
+def test_s3_blob_preview_url_rejected():
+ assert lso.is_url("s3:/blob:http://localhost:3000/preview") is None
+
+
+def test_parse_s3_blob_preview_url_rejected():
+ with pytest.raises(ValueError, match="Invalid s3:// URL format"):
+ lso.parse_s3_url("s3:/blob:http://localhost:3000/preview")
+
+
def test_download_file_from_http(monkeypatch):
manager = make_manager()
From 5a85a5ba94acfef67aeb43ebf1af829a3ab996a6 Mon Sep 17 00:00:00 2001
From: 827dls <1670704430@qq.com>
Date: Sun, 17 May 2026 17:20:11 +0800
Subject: [PATCH 2/7] =?UTF-8?q?=E8=A1=A5=E5=85=85test=E6=B5=8B=E8=AF=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../services/test_config_sync_service.py | 44 ++++++++++++++++---
.../test_config_sync_service_voice.py | 32 +++++++++++++-
.../services/test_model_management_service.py | 7 ++-
3 files changed, 75 insertions(+), 8 deletions(-)
diff --git a/test/backend/services/test_config_sync_service.py b/test/backend/services/test_config_sync_service.py
index 0748a71b7..9928d0315 100644
--- a/test/backend/services/test_config_sync_service.py
+++ b/test/backend/services/test_config_sync_service.py
@@ -1,4 +1,6 @@
import sys
+import types
+import importlib
from unittest.mock import patch, MagicMock, call
import pytest
@@ -22,6 +24,31 @@
minio_config_mock = MagicMock()
minio_config_mock.validate = MagicMock()
+if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'):
+ sys.modules.pop('consts.const', None)
+if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'):
+ sys.modules.pop('consts', None)
+
+database_client_module = types.ModuleType('database.client')
+database_client_module.MinioClient = MagicMock()
+database_client_module.minio_client = minio_client_mock
+database_client_module.as_dict = MagicMock(side_effect=lambda value: value)
+database_client_module.db_client = MagicMock()
+database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value)
+database_client_module.get_db_session = MagicMock()
+sys.modules['database.client'] = database_client_module
+database_package = sys.modules.get('database') or importlib.import_module('database')
+setattr(database_package, 'client', database_client_module)
+database_model_management_module = types.ModuleType('database.model_management_db')
+database_model_management_module.get_model_by_model_id = MagicMock()
+database_model_management_module.get_model_id_by_display_name = MagicMock()
+database_model_management_module.get_model_records = MagicMock(return_value=[])
+sys.modules['database.model_management_db'] = database_model_management_module
+setattr(database_package, 'model_management_db', database_model_management_module)
+backend_database_client_module = sys.modules.get('backend.database.client')
+if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'):
+ backend_database_client_module.minio_client = minio_client_mock
+
patch('nexent.storage.storage_client_factory.create_storage_client_from_config',
return_value=storage_client_mock).start()
patch('nexent.storage.minio_config.MinIOStorageConfig',
@@ -29,7 +56,7 @@
patch('backend.database.client.MinioClient',
return_value=minio_client_mock).start()
patch('database.client.MinioClient', return_value=minio_client_mock).start()
-patch('backend.database.client.minio_client', minio_client_mock).start()
+patch('backend.database.client.minio_client', minio_client_mock, create=True).start()
patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start()
# Import backend modules after all patches are applied
@@ -52,14 +79,17 @@ def service_mocks():
with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \
patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \
patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \
+ patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \
patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \
patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \
patch('backend.services.config_sync_service.logger') as mock_logger:
+ mock_get_model_records.return_value = []
yield {
'tenant_config_manager': mock_tenant_config_manager,
'get_env_key': mock_get_env_key,
'safe_value': mock_safe_value,
+ 'get_model_records': mock_get_model_records,
'get_model_id': mock_get_model_id,
'get_model_name': mock_get_model_name,
'logger': mock_logger
@@ -1336,6 +1366,8 @@ def side_effect(config_key, tenant_id=None):
"MULTI_EMBEDDING_ID": {},
"RERANK_ID": {},
"VLM_ID": {},
+ "VLM2_ID": {},
+ "VLM3_ID": {},
"STT_ID": {},
"TTS_ID": {}
}
@@ -1348,7 +1380,7 @@ def side_effect(config_key, tenant_id=None):
# Assert
assert isinstance(result, dict)
- assert len(result) == 7 # All model types should be present
+ assert len(result) == 9 # All model types should be present
# Verify successful configs
assert result["llm"]["displayName"] == "GPT-4"
@@ -1372,20 +1404,20 @@ def test_build_models_config_all_failures(self, service_mocks):
# Assert
assert isinstance(result, dict)
# All model types should still be present with empty configs
- assert len(result) == 7
+ assert len(result) == 9
# All configs should be empty due to exceptions
- for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "stt", "tts"]:
+ for model_key in ["llm", "embedding", "multiEmbedding", "rerank", "vlm", "vlm2", "vlm3", "stt", "tts"]:
assert result[model_key]["name"] == ""
assert result[model_key]["displayName"] == ""
assert result[model_key]["apiConfig"]["apiKey"] == ""
assert result[model_key]["apiConfig"]["modelUrl"] == ""
# Verify that logger.warning was called for each model type
- assert service_mocks['logger'].warning.call_count == 7
+ assert service_mocks['logger'].warning.call_count == 9
warning_calls = service_mocks['logger'].warning.call_args_list
expected_configs = ["LLM_ID", "EMBEDDING_ID", "MULTI_EMBEDDING_ID",
- "RERANK_ID", "VLM_ID", "STT_ID", "TTS_ID"]
+ "RERANK_ID", "VLM_ID", "VLM2_ID", "VLM3_ID", "STT_ID", "TTS_ID"]
for i, config_key in enumerate(expected_configs):
assert f"Failed to get config for {config_key}: Database completely down" in warning_calls[
i][0][0]
diff --git a/test/backend/services/test_config_sync_service_voice.py b/test/backend/services/test_config_sync_service_voice.py
index fcfd531f1..1a3144036 100644
--- a/test/backend/services/test_config_sync_service_voice.py
+++ b/test/backend/services/test_config_sync_service_voice.py
@@ -3,6 +3,8 @@
These tests cover the STT specific fields in save_config_impl.
"""
import sys
+import types
+import importlib
from unittest.mock import patch, MagicMock
import pytest
@@ -22,6 +24,31 @@
minio_config_mock = MagicMock()
minio_config_mock.validate = MagicMock()
+if 'consts.const' in sys.modules and not hasattr(sys.modules['consts.const'], 'APP_DESCRIPTION'):
+ sys.modules.pop('consts.const', None)
+if 'consts' in sys.modules and not hasattr(sys.modules['consts'], '__path__'):
+ sys.modules.pop('consts', None)
+
+database_client_module = types.ModuleType('database.client')
+database_client_module.MinioClient = MagicMock()
+database_client_module.minio_client = minio_client_mock
+database_client_module.as_dict = MagicMock(side_effect=lambda value: value)
+database_client_module.db_client = MagicMock()
+database_client_module.db_client.clean_string_values = MagicMock(side_effect=lambda value: value)
+database_client_module.get_db_session = MagicMock()
+sys.modules['database.client'] = database_client_module
+database_package = sys.modules.get('database') or importlib.import_module('database')
+setattr(database_package, 'client', database_client_module)
+database_model_management_module = types.ModuleType('database.model_management_db')
+database_model_management_module.get_model_by_model_id = MagicMock()
+database_model_management_module.get_model_id_by_display_name = MagicMock()
+database_model_management_module.get_model_records = MagicMock(return_value=[])
+sys.modules['database.model_management_db'] = database_model_management_module
+setattr(database_package, 'model_management_db', database_model_management_module)
+backend_database_client_module = sys.modules.get('backend.database.client')
+if backend_database_client_module is not None and not hasattr(backend_database_client_module, 'minio_client'):
+ backend_database_client_module.minio_client = minio_client_mock
+
patch('nexent.storage.storage_client_factory.create_storage_client_from_config',
return_value=storage_client_mock).start()
patch('nexent.storage.minio_config.MinIOStorageConfig',
@@ -29,7 +56,7 @@
patch('backend.database.client.MinioClient',
return_value=minio_client_mock).start()
patch('database.client.MinioClient', return_value=minio_client_mock).start()
-patch('backend.database.client.minio_client', minio_client_mock).start()
+patch('backend.database.client.minio_client', minio_client_mock, create=True).start()
patch('elasticsearch.Elasticsearch', return_value=MagicMock()).start()
# Import backend modules after all patches are applied
@@ -47,14 +74,17 @@ def service_mocks():
with patch('backend.services.config_sync_service.tenant_config_manager') as mock_tenant_config_manager, \
patch('backend.services.config_sync_service.get_env_key') as mock_get_env_key, \
patch('backend.services.config_sync_service.safe_value') as mock_safe_value, \
+ patch('backend.services.config_sync_service.get_model_records') as mock_get_model_records, \
patch('backend.services.config_sync_service.get_model_id_by_display_name') as mock_get_model_id, \
patch('backend.services.config_sync_service.get_model_name_from_config') as mock_get_model_name, \
patch('backend.services.config_sync_service.logger') as mock_logger:
+ mock_get_model_records.return_value = []
yield {
'tenant_config_manager': mock_tenant_config_manager,
'get_env_key': mock_get_env_key,
'safe_value': mock_safe_value,
+ 'get_model_records': mock_get_model_records,
'get_model_id': mock_get_model_id,
'get_model_name': mock_get_model_name,
'logger': mock_logger
diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py
index bd56cce40..f85cbfe05 100644
--- a/test/backend/services/test_model_management_service.py
+++ b/test/backend/services/test_model_management_service.py
@@ -315,6 +315,11 @@ def _add_repo_to_name(model_repo, model_name):
def import_svc():
"""Import service under MinioClient patch to avoid real initialization."""
minio_client_mock = mock.MagicMock()
+ sys.modules["database"] = database_mod
+ sys.modules["database.model_management_db"] = db_mm_mod
+ setattr(database_mod, "model_management_db", db_mm_mod)
+ sys.modules.pop("backend.services.model_management_service", None)
+ sys.modules.pop("services.model_management_service", None)
with mock.patch("backend.database.client.MinioClient", return_value=minio_client_mock):
from backend.services import model_management_service as svc # type: ignore
return svc
@@ -727,7 +732,7 @@ async def test_batch_create_models_max_tokens_update():
with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=existing), \
mock.patch.object(svc, "delete_model_record"), \
mock.patch.object(svc, "split_repo_name", side_effect=lambda x: ("silicon", x.split("/")[1] if "/" in x else x)), \
- mock.patch.object(svc, "add_repo_to_name", side_effect=lambda r, n: f"{r}/{n}"), \
+ mock.patch.object(svc, "add_repo_to_name", side_effect=lambda *args, **kwargs: f"{kwargs.get('model_repo', args[0] if args else '')}/{kwargs.get('model_name', args[1] if len(args) > 1 else '')}"), \
mock.patch.object(svc, "update_model_record") as mock_update, \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \
mock.patch.object(svc, "create_model_record", return_value=True):
From 37b87c4e3d201bd7f641d3d9e4ef30d1559e9ca7 Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Wed, 27 May 2026 10:57:41 +0800
Subject: [PATCH 3/7] Chore: Update .dockerignore to remove unnecessary backend
assets
---
.dockerignore | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/.dockerignore b/.dockerignore
index 45c1def32..385a6449f 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -37,8 +37,6 @@ build/
*.tgz
# Backend
-backend/assets/*
-!backend/assets/test.wav
backend/flower_db.sqlite
uploads/
test/
@@ -60,4 +58,4 @@ assets/
.Spotlight-V100
.Trashes
ehthumbs.db
-Thumbs.db
\ No newline at end of file
+Thumbs.db
From 6986c45f5544a4144e4b1db5211a09fa9d85a137 Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Wed, 27 May 2026 11:25:15 +0800
Subject: [PATCH 4/7] Feat: add optional id field to SingleModelConfig
interface
---
frontend/types/modelConfig.ts | 1 +
1 file changed, 1 insertion(+)
diff --git a/frontend/types/modelConfig.ts b/frontend/types/modelConfig.ts
index d5df7459d..8f4789f6b 100644
--- a/frontend/types/modelConfig.ts
+++ b/frontend/types/modelConfig.ts
@@ -91,6 +91,7 @@ export interface TTSModelConfig extends SingleModelConfig {
// Single model configuration interface
export interface SingleModelConfig {
+ id?: number;
modelName: string;
displayName: string;
apiConfig: ModelApiConfig;
From c55c33c8f45b38804125d1e8d40286915027ad97 Mon Sep 17 00:00:00 2001
From: 827dls <1670704430@qq.com>
Date: Wed, 27 May 2026 23:06:50 +0800
Subject: [PATCH 5/7] fix_audio_tools,Support adding models individually
---
backend/services/model_health_service.py | 42 +++++++++
backend/services/model_management_service.py | 22 ++++-
.../services/providers/dashscope_provider.py | 91 ++++++++++++++++++-
.../services/providers/tokenpony_provider.py | 76 +++++++++++++++-
4 files changed, 221 insertions(+), 10 deletions(-)
diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py
index 58a0af91f..09d37fa0d 100644
--- a/backend/services/model_health_service.py
+++ b/backend/services/model_health_service.py
@@ -15,6 +15,11 @@
logger = logging.getLogger("model_health_service")
+DASHSCOPE_MODEL_FACTORY = "dashscope"
+TOKENPONY_MODEL_FACTORY = "tokenpony"
+PROVIDER_CATALOG_HEALTHCHECK_FACTORIES = {DASHSCOPE_MODEL_FACTORY, TOKENPONY_MODEL_FACTORY}
+PROVIDER_CATALOG_HEALTHCHECK_TYPES = {"vlm", "vlm2", "vlm3"}
+
def _mask_secret(value: Optional[str]) -> str:
"""Mask a secret value, showing only first and last 4 characters."""
@@ -64,6 +69,31 @@ async def _embedding_dimension_check(
raise ValueError(f"Unsupported model type: {model_type}")
+async def _provider_catalog_connectivity_check(
+ model_name: str,
+ model_type: str,
+ model_api_key: str,
+ model_factory: Optional[str],
+) -> bool:
+ """Validate provider-managed multimodal models through their model catalog."""
+ provider = (model_factory or "").lower()
+ if provider not in PROVIDER_CATALOG_HEALTHCHECK_FACTORIES:
+ return False
+
+ from services.model_provider_service import get_provider_models
+
+ model_list = await get_provider_models({
+ "provider": provider,
+ "model_type": model_type,
+ "api_key": model_api_key,
+ })
+ if not model_list or any(model.get("_error") for model in model_list):
+ return False
+
+ expected_model_id = model_name.lower()
+ return any(str(model.get("id", "")).lower() == expected_model_id for model in model_list)
+
+
async def _perform_connectivity_check(
model_name: str,
model_type: str,
@@ -135,6 +165,18 @@ async def _perform_connectivity_check(
)
connectivity = await rerank_model.connectivity_check()
elif model_type in ("vlm", "vlm2", "vlm3"):
+ if (
+ model_type in PROVIDER_CATALOG_HEALTHCHECK_TYPES
+ and (model_factory or "").lower() in PROVIDER_CATALOG_HEALTHCHECK_FACTORIES
+ ):
+ connectivity = await _provider_catalog_connectivity_check(
+ model_name=model_name,
+ model_type=model_type,
+ model_api_key=model_api_key,
+ model_factory=model_factory,
+ )
+ return connectivity
+
observer = MessageObserver()
set_monitoring_operation("connectivity_check",
display_name=display_name)
diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py
index 8f6d191fd..9f032728a 100644
--- a/backend/services/model_management_service.py
+++ b/backend/services/model_management_service.py
@@ -8,7 +8,6 @@
from database.model_management_db import (
create_model_record,
delete_model_record,
- get_model_by_display_name,
get_model_by_name_factory,
get_models_by_display_name,
get_model_records,
@@ -32,6 +31,23 @@
logger = logging.getLogger("model_management_service")
+INDEPENDENT_MULTIMODAL_MODEL_TYPES = {"vlm", "vlm2", "vlm3"}
+
+
+def _has_display_name_conflict(existing_models: List[Dict[str, Any]], model_type: Optional[str]) -> bool:
+ """Allow the three multimodal slots to share display names across slots."""
+ if not existing_models:
+ return False
+
+ if model_type in INDEPENDENT_MULTIMODAL_MODEL_TYPES:
+ return any(
+ existing.get("model_type") == model_type
+ or existing.get("model_type") not in INDEPENDENT_MULTIMODAL_MODEL_TYPES
+ for existing in existing_models
+ )
+
+ return True
+
async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict[str, Any]):
"""Create a single model record for the given tenant.
@@ -77,9 +93,9 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict
# Check display name conflict scoped by tenant
if model_data.get("display_name"):
- existing_model_by_display = get_model_by_display_name(
+ existing_models_by_display = get_models_by_display_name(
model_data["display_name"], tenant_id)
- if existing_model_by_display:
+ if _has_display_name_conflict(existing_models_by_display, model_data.get("model_type")):
logging.error(
f"Name {model_data['display_name']} is already in use, please choose another display name")
raise ValueError(
diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py
index 69096fb15..497dcfe99 100644
--- a/backend/services/providers/dashscope_provider.py
+++ b/backend/services/providers/dashscope_provider.py
@@ -6,6 +6,75 @@
from services.providers.base import AbstractModelProvider, _classify_provider_error
+DASHSCOPE_IMAGE_GENERATION_KEYWORDS = (
+ "image",
+ "wanx",
+ "aitryon",
+ "tryon",
+ "flux",
+ "stable-diffusion",
+ "sdxl",
+)
+DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS = (
+ "qwen-vl",
+ "qwen2-vl",
+ "qwen2.5-vl",
+ "qwen3-vl",
+ "qwen3.5-vl",
+ "qwen3.6-vl",
+ "-vl",
+ "vl-",
+ "vision",
+ "visual",
+ "ocr",
+ "qwen3.6",
+ "qwen-3.6",
+)
+DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video-understanding", "video-ocr")
+
+
+def _modality_set(value) -> set:
+ if not value:
+ return set()
+ if isinstance(value, str):
+ return {value.lower()}
+ return {str(item).lower() for item in value}
+
+
+def _has_keyword(text: str, keywords: tuple) -> bool:
+ return any(keyword in text for keyword in keywords)
+
+
+def _is_dashscope_explicit_image_understanding_model(model_id: str) -> bool:
+ return _has_keyword(model_id, DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS)
+
+
+def _is_dashscope_image_generation_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool:
+ if _is_dashscope_explicit_image_understanding_model(model_id):
+ return False
+ return "image" in res_mods or _has_keyword(model_id, DASHSCOPE_IMAGE_GENERATION_KEYWORDS)
+
+
+def _is_dashscope_video_understanding_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool:
+ searchable_text = f"{model_id} {desc.lower()}"
+ if "video" in req_mods and "text" in res_mods:
+ return True
+ return _has_keyword(searchable_text, DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS)
+
+
+def _is_dashscope_image_understanding_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool:
+ searchable_text = f"{model_id} {desc.lower()}"
+ if _is_dashscope_image_generation_model(model_id, desc, req_mods, res_mods):
+ return False
+ if _is_dashscope_video_understanding_model(model_id, desc, req_mods, res_mods):
+ return False
+ if ("image" in req_mods or "video" in req_mods) and "text" in res_mods:
+ return True
+ return _is_dashscope_explicit_image_understanding_model(model_id) or _has_keyword(
+ searchable_text, DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS
+ )
+
+
class DashScopeModelProvider(AbstractModelProvider):
"""Concrete implementation for DashScope (Aliyun) provider."""
@@ -57,6 +126,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
categorized_models = {
"chat": [], # Maps to "llm"
"vlm": [], # Maps to "vlm"
+ "vlm2": [], # Maps to image generation models
+ "vlm3": [], # Maps to video understanding models
"embedding": [], # Maps to "embedding" / "multi_embedding"
"rerank": [], # Maps to "rerank"
"tts": [], # Maps to "tts"
@@ -71,6 +142,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
metadata = model_obj.get('inference_metadata') or {}
req_mod = metadata.get('request_modality', [])
res_mod = metadata.get('response_modality', [])
+ req_mods = _modality_set(req_mod)
+ res_mods = _modality_set(res_mod)
model_obj.setdefault("object", model_obj.get("object", "model"))
model_obj.setdefault("owned_by", model_obj.get("owned_by", "dashscope"))
cleaned_model = {
@@ -107,8 +180,17 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
continue
# 5. VLM
- vision_mods = {'Image', 'Video'}
- if (set(req_mod) & vision_mods) or (set(res_mod) & vision_mods) or '视觉' in desc:
+ if _is_dashscope_video_understanding_model(m_id, desc, req_mods, res_mods):
+ cleaned_model.update({"model_tag": "chat", "model_type": "vlm3"})
+ categorized_models['vlm3'].append(cleaned_model)
+ continue
+
+ if _is_dashscope_image_generation_model(m_id, desc, req_mods, res_mods):
+ cleaned_model.update({"model_tag": "chat", "model_type": "vlm2"})
+ categorized_models['vlm2'].append(cleaned_model)
+ continue
+
+ if _is_dashscope_image_understanding_model(m_id, desc, req_mods, res_mods):
cleaned_model.update({"model_tag": "chat", "model_type": "vlm"})
categorized_models['vlm'].append(cleaned_model)
continue
@@ -124,7 +206,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
elif target_model_type in ("embedding", "multi_embedding"):
return categorized_models["embedding"]
elif target_model_type in categorized_models:
- return categorized_models[target_model_type]
+ return [
+ {**model, "model_type": target_model_type}
+ for model in categorized_models[target_model_type]
+ ]
else:
return []
except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e:
diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py
index ab4446c1b..be2bb9c71 100644
--- a/backend/services/providers/tokenpony_provider.py
+++ b/backend/services/providers/tokenpony_provider.py
@@ -9,6 +9,64 @@
from services.providers.base import AbstractModelProvider, _classify_provider_error
+TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS = (
+ "qwen-vl",
+ "qwen2-vl",
+ "qwen2.5-vl",
+ "qwen3-vl",
+ "qwen3.5-vl",
+ "qwen3.6-vl",
+ "-vl",
+ "vl-",
+ "vision",
+ "visual",
+ "ocr",
+ "gpt-4o",
+ "qwen3.6",
+ "qwen-3.6",
+)
+TOKENPONY_IMAGE_GENERATION_KEYWORDS = (
+ "image",
+ "dall",
+ "flux",
+ "stable-diffusion",
+ "sdxl",
+ "midjourney",
+ "wanx",
+ "kolors",
+ "seedream",
+ "ideogram",
+ "recraft",
+)
+TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video")
+
+
+def _has_keyword(text: str, keywords: tuple) -> bool:
+ return any(keyword in text for keyword in keywords)
+
+
+def _is_tokenpony_explicit_image_understanding_model(model_id: str) -> bool:
+ return _has_keyword(model_id, TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS)
+
+
+def _is_tokenpony_image_generation_model(model_id: str) -> bool:
+ if _is_tokenpony_explicit_image_understanding_model(model_id):
+ return False
+ return _has_keyword(model_id, TOKENPONY_IMAGE_GENERATION_KEYWORDS)
+
+
+def _is_tokenpony_video_understanding_model(model_id: str) -> bool:
+ return _has_keyword(model_id, TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS)
+
+
+def _is_tokenpony_image_understanding_model(model_id: str) -> bool:
+ if _is_tokenpony_image_generation_model(model_id):
+ return False
+ if _is_tokenpony_video_understanding_model(model_id):
+ return False
+ return _is_tokenpony_explicit_image_understanding_model(model_id)
+
+
class TokenPonyModelProvider(AbstractModelProvider):
"""Concrete implementation for TokenPony provider."""
@@ -46,6 +104,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
categorized_models = {
"chat": [], # Maps to "llm"
"vlm": [], # Maps to "vlm"
+ "vlm2": [], # Maps to image generation models
+ "vlm3": [], # Maps to video understanding models
"embedding": [], # Maps to "embedding" / "multi_embedding"
"rerank": [], # Maps to "rerank"
"tts": [], # Maps to "tts"
@@ -86,9 +146,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
cleaned_model.update({"model_tag": "tts", "model_type": "tts"})
categorized_models['tts'].append(cleaned_model)
- # 5. VLM (Vision Language Model / Image & Video Generation)
-
- elif any(keyword in m_id for keyword in ['-vl', 'vl-', 'ocr', 'vision']):
+ # 5. Multimodal models
+ elif _is_tokenpony_video_understanding_model(m_id):
+ cleaned_model.update({"model_tag": "chat", "model_type": "vlm3"})
+ categorized_models['vlm3'].append(cleaned_model)
+ elif _is_tokenpony_image_generation_model(m_id):
+ cleaned_model.update({"model_tag": "chat", "model_type": "vlm2"})
+ categorized_models['vlm2'].append(cleaned_model)
+ elif _is_tokenpony_image_understanding_model(m_id):
cleaned_model.update({"model_tag": "chat", "model_type": "vlm"})
categorized_models['vlm'].append(cleaned_model)
@@ -104,7 +169,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
elif target_model_type in ("embedding", "multi_embedding"):
return categorized_models["embedding"]
elif target_model_type in categorized_models:
- return categorized_models[target_model_type]
+ return [
+ {**model, "model_type": target_model_type}
+ for model in categorized_models[target_model_type]
+ ]
else:
return []
From e74adb7dec1220d2e0e937116ce6e7531b178a3e Mon Sep 17 00:00:00 2001
From: 827dls <1670704430@qq.com>
Date: Thu, 28 May 2026 00:29:29 +0800
Subject: [PATCH 6/7] fix test
---
sdk/nexent/core/tools/analyze_audio_tool.py | 62 ++++----
sdk/nexent/core/tools/analyze_video_tool.py | 72 +++++----
.../providers/test_dashscope_provider.py | 140 +++++++++++++++++-
.../providers/test_tokenpony_provider.py | 124 +++++++++++++++-
.../services/test_model_health_service.py | 46 ++++++
.../services/test_model_management_service.py | 136 +++++++++++++++--
.../tools/test_analyze_audio_video_tool.py | 58 +++++++-
7 files changed, 560 insertions(+), 78 deletions(-)
diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py
index c7509a6c2..1e5439443 100644
--- a/sdk/nexent/core/tools/analyze_audio_tool.py
+++ b/sdk/nexent/core/tools/analyze_audio_tool.py
@@ -7,7 +7,7 @@
import logging
from io import BytesIO
-from typing import List
+from typing import List, Optional
from jinja2 import StrictUndefined, Template
from pydantic import Field
@@ -28,28 +28,29 @@ class AnalyzeAudioTool(Tool):
"""Tool for understanding and analyzing audio using the video understanding model."""
name = "analyze_audio"
+ skip_forward_signature_validation = True
description = (
"This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n"
- "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
+ "It is used to understand and analyze one audio file, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
"HTTP, and HTTPS URLs.\n"
"Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query."
)
description_zh = (
- "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。"
- "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ "使用视频理解模型,根据你的问题理解音频,并返回音频分析结果。"
+ "可用于理解和分析一个音频文件,支持 S3 URL(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
)
inputs = {
- "audio_urls_list": {
- "type": "array",
- "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
- "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。"
+ "audio_url": {
+ "type": "string",
+ "description": "Audio URL (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
+ "description_zh": "音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。",
},
"query": {
"type": "string",
"description": "User's question to guide the audio analysis",
- "description_zh": "用户用于指导音频分析的问题"
- }
+ "description_zh": "用户的问题,用于指导音频分析",
+ },
}
init_param_descriptions = {
@@ -58,9 +59,9 @@ class AnalyzeAudioTool(Tool):
"storage_client": {"description": "Storage client for downloading files"},
"validate_url_access": {
"description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)"
- }
+ },
}
- output_type = "array"
+ output_type = "string"
category = ToolCategory.MULTIMODAL.value
tool_sign = ToolSign.MULTIMODAL_OPERATION.value
@@ -94,10 +95,10 @@ def __init__(
validate_callback = validate_url_access
self.mm = LoadSaveObjectManager(
storage_client=self.storage_client,
- validate_url_access=validate_callback
+ validate_url_access=validate_callback,
)
self.forward = self.mm.load_object(
- input_names=["audio_urls_list"])(self._forward_impl)
+ input_names=["audio_url", "audio_urls_list"])(self._forward_impl)
self.running_prompt_zh = "正在分析音频..."
self.running_prompt_en = "Analyzing audio..."
@@ -114,10 +115,14 @@ def _validate_audio_capable_model(self) -> None:
"Please choose a Qwen3-Omni model for analyze_audio."
)
- def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
- """Analyze audio files and return one result per audio input."""
+ def _forward_impl(
+ self,
+ audio_url: Optional[bytes] = None,
+ query: str = "",
+ audio_urls_list: Optional[List[bytes]] = None) -> str:
+ """Analyze an audio file and return the result as a string."""
if self.vlm_model is None:
- error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试"
+ error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试。"
error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again."
error_msg = error_msg_zh if self._is_chinese else error_msg_en
logger.error(error_msg)
@@ -128,12 +133,17 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en
self.observer.add_message("", ProcessType.TOOL, running_prompt)
- if audio_urls_list is None:
- raise ValueError("audio_urls cannot be None")
- if not isinstance(audio_urls_list, list):
- raise ValueError("audio_urls must be a list of bytes")
- if not audio_urls_list:
- raise ValueError("audio_urls must contain at least one audio file")
+ if audio_url is not None:
+ audio_items = [audio_url]
+ else:
+ audio_items = audio_urls_list
+
+ if audio_items is None:
+ raise ValueError("audio_url cannot be None")
+ if not isinstance(audio_items, list):
+ raise ValueError("audio_url must be bytes or audio_urls_list must be a list of bytes")
+ if not audio_items:
+ raise ValueError("audio_url must contain an audio file")
language = self.observer.lang if self.observer else "en"
prompts = get_prompt_template(
@@ -143,7 +153,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
try:
analysis_results: List[str] = []
- for index, audio_bytes in enumerate(audio_urls_list, start=1):
+ for index, audio_bytes in enumerate(audio_items, start=1):
logger.info(f"Analyzing audio #{index}, query: {query}")
content_type = detect_content_type_from_bytes(audio_bytes)
if not content_type.startswith("audio/"):
@@ -153,7 +163,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
response = self.vlm_model.analyze_audio(
audio_input=audio_stream,
system_prompt=system_prompt,
- content_type=content_type
+ content_type=content_type,
)
except Exception as e:
error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。"
@@ -163,7 +173,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]:
analysis_results.append(response.content)
- return analysis_results
+ return "\n\n".join(analysis_results)
except Exception as e:
logger.error(f"Error analyzing audio: {str(e)}", exc_info=True)
raise Exception(f"Error analyzing audio: {str(e)}")
diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py
index 3dc033551..e7bf84549 100644
--- a/sdk/nexent/core/tools/analyze_video_tool.py
+++ b/sdk/nexent/core/tools/analyze_video_tool.py
@@ -1,13 +1,13 @@
"""
Analyze Video Tool
-Analyze videos using the configured video understanding model.
-Supports videos from S3, HTTP, and HTTPS URLs.
+Analyze video using the configured video understanding model.
+Supports video from S3, HTTP, and HTTPS URLs.
"""
import logging
from io import BytesIO
-from typing import List
+from typing import List, Optional
from jinja2 import StrictUndefined, Template
from pydantic import Field
@@ -25,31 +25,32 @@
class AnalyzeVideoTool(Tool):
- """Tool for understanding and analyzing videos using the video understanding model."""
+ """Tool for understanding and analyzing video using the video understanding model."""
name = "analyze_video"
+ skip_forward_signature_validation = True
description = (
- "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n"
- "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
+ "This tool uses the configured video understanding model to understand video based on your query and then returns a video analysis result.\n"
+ "It is used to understand and analyze one video, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), "
"HTTP, and HTTPS URLs.\n"
- "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query."
+ "Use this tool when you want to retrieve information contained in a video and provide the video URL and your query."
)
description_zh = (
- "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。"
- "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
+ "使用视频理解模型,根据你的问题理解视频,并返回视频分析结果。"
+ "可用于理解和分析一个视频,支持 S3 URL(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。"
)
inputs = {
- "video_urls_list": {
- "type": "array",
- "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
- "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。"
+ "video_url": {
+ "type": "string",
+ "description": "Video URL (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.",
+ "description_zh": "视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。",
},
"query": {
"type": "string",
"description": "User's question to guide the video analysis",
- "description_zh": "用户用于指导视频分析的问题"
- }
+ "description_zh": "用户的问题,用于指导视频分析",
+ },
}
init_param_descriptions = {
@@ -58,9 +59,9 @@ class AnalyzeVideoTool(Tool):
"storage_client": {"description": "Storage client for downloading files"},
"validate_url_access": {
"description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)"
- }
+ },
}
- output_type = "array"
+ output_type = "string"
category = ToolCategory.MULTIMODAL.value
tool_sign = ToolSign.MULTIMODAL_OPERATION.value
@@ -94,18 +95,22 @@ def __init__(
validate_callback = validate_url_access
self.mm = LoadSaveObjectManager(
storage_client=self.storage_client,
- validate_url_access=validate_callback
+ validate_url_access=validate_callback,
)
self.forward = self.mm.load_object(
- input_names=["video_urls_list"])(self._forward_impl)
+ input_names=["video_url", "video_urls_list"])(self._forward_impl)
self.running_prompt_zh = "正在分析视频..."
self.running_prompt_en = "Analyzing video..."
- def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
- """Analyze videos and return one result per video input."""
+ def _forward_impl(
+ self,
+ video_url: Optional[bytes] = None,
+ query: str = "",
+ video_urls_list: Optional[List[bytes]] = None) -> str:
+ """Analyze a video and return the result as a string."""
if self.vlm_model is None:
- error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试"
+ error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试。"
error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again."
error_msg = error_msg_zh if self._is_chinese else error_msg_en
logger.error(error_msg)
@@ -115,12 +120,17 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en
self.observer.add_message("", ProcessType.TOOL, running_prompt)
- if video_urls_list is None:
- raise ValueError("video_urls cannot be None")
- if not isinstance(video_urls_list, list):
- raise ValueError("video_urls must be a list of bytes")
- if not video_urls_list:
- raise ValueError("video_urls must contain at least one video")
+ if video_url is not None:
+ video_items = [video_url]
+ else:
+ video_items = video_urls_list
+
+ if video_items is None:
+ raise ValueError("video_url cannot be None")
+ if not isinstance(video_items, list):
+ raise ValueError("video_url must be bytes or video_urls_list must be a list of bytes")
+ if not video_items:
+ raise ValueError("video_url must contain a video")
language = self.observer.lang if self.observer else "en"
prompts = get_prompt_template(
@@ -130,7 +140,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
try:
analysis_results: List[str] = []
- for index, video_bytes in enumerate(video_urls_list, start=1):
+ for index, video_bytes in enumerate(video_items, start=1):
logger.info(f"Analyzing video #{index}, query: {query}")
content_type = detect_content_type_from_bytes(video_bytes)
if not content_type.startswith("video/"):
@@ -140,7 +150,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
response = self.vlm_model.analyze_video(
video_input=video_stream,
system_prompt=system_prompt,
- content_type=content_type
+ content_type=content_type,
)
except Exception as e:
error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。"
@@ -150,7 +160,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]:
analysis_results.append(response.content)
- return analysis_results
+ return "\n\n".join(analysis_results)
except Exception as e:
logger.error(f"Error analyzing video: {str(e)}", exc_info=True)
raise Exception(f"Error analyzing video: {str(e)}")
diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py
index 30229677a..5c6267040 100644
--- a/test/backend/services/providers/test_dashscope_provider.py
+++ b/test/backend/services/providers/test_dashscope_provider.py
@@ -141,11 +141,27 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
"models": [
{
"model": "qwen-vl-plus",
- "description": "Vision language model",
+ "description": "Vision language model for image understanding",
"inference_metadata": {
"request_modality": ["Image", "Text"],
"response_modality": ["Text"]
}
+ },
+ {
+ "model": "qwen3.6-27b",
+ "description": "Qwen 3.6 multimodal model",
+ "inference_metadata": {
+ "request_modality": ["Text"],
+ "response_modality": ["Text"]
+ }
+ },
+ {
+ "model": "qwen-vl-max",
+ "description": "Qwen VL max model",
+ "inference_metadata": {
+ "request_modality": ["Image", "Text"],
+ "response_modality": ["Text", "Image"]
+ }
}
]
}
@@ -167,11 +183,129 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
result = await provider.get_models(provider_config)
+ assert [model["id"] for model in result] == ["qwen-vl-plus", "qwen3.6-27b", "qwen-vl-max"]
+ assert all(model["model_type"] == "vlm" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+
+ @pytest.mark.asyncio
+ async def test_get_models_vlm2_only_returns_image_generation_models(self, mocker: MockFixture):
+ """Image generation slot only returns image-generation multimodal models."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "output": {
+ "models": [
+ {
+ "model": "qwen-vl-plus",
+ "description": "Vision language model",
+ "inference_metadata": {
+ "request_modality": ["Image", "Text"],
+ "response_modality": ["Text"]
+ }
+ },
+ {
+ "model": "qwen-image-max",
+ "description": "Image generation model",
+ "inference_metadata": {
+ "request_modality": ["Text"],
+ "response_modality": ["Image"]
+ }
+ },
+ {
+ "model": "qwen-vl-max",
+ "description": "Qwen VL max model",
+ "inference_metadata": {
+ "request_modality": ["Image", "Text"],
+ "response_modality": ["Text", "Image"]
+ }
+ },
+ {
+ "model": "qwen-plus",
+ "description": "Text generation model",
+ "inference_metadata": {
+ "request_modality": ["Text"],
+ "response_modality": ["Text"]
+ }
+ }
+ ]
+ }
+ }
+ mock_response.raise_for_status = MagicMock()
+
+ self._setup_mock_client(mocker, mock_response)
+
+ provider = DashScopeModelProvider()
+ provider_config = {
+ "model_type": "vlm2",
+ "api_key": "test-api-key"
+ }
+
+ result = await provider.get_models(provider_config)
+
assert len(result) == 1
- assert result[0]["id"] == "qwen-vl-plus"
- assert result[0]["model_type"] == "vlm"
+ assert result[0]["id"] == "qwen-image-max"
+ assert result[0]["model_type"] == "vlm2"
assert result[0]["model_tag"] == "chat"
+ @pytest.mark.asyncio
+ async def test_get_models_vlm3_only_returns_video_understanding_models(self, mocker: MockFixture):
+ """Video understanding slot excludes image generation and text-only models."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "output": {
+ "models": [
+ {
+ "model": "qwen-image-max",
+ "description": "Image generation model",
+ "inference_metadata": {
+ "request_modality": ["Text"],
+ "response_modality": ["Image"]
+ }
+ },
+ {
+ "model": "qwen-omni-turbo",
+ "description": "Video understanding model",
+ "inference_metadata": {
+ "request_modality": ["Video", "Text"],
+ "response_modality": ["Text"]
+ }
+ },
+ {
+ "model": "qwen3-omni-30b-a3b-instruct",
+ "description": "Omni multimodal model",
+ "inference_metadata": {
+ "request_modality": ["Text"],
+ "response_modality": ["Text"]
+ }
+ },
+ {
+ "model": "qwen-plus",
+ "description": "Text generation model",
+ "inference_metadata": {
+ "request_modality": ["Text"],
+ "response_modality": ["Text"]
+ }
+ }
+ ]
+ }
+ }
+ mock_response.raise_for_status = MagicMock()
+
+ self._setup_mock_client(mocker, mock_response)
+
+ provider = DashScopeModelProvider()
+ provider_config = {
+ "model_type": "vlm3",
+ "api_key": "test-api-key"
+ }
+
+ result = await provider.get_models(provider_config)
+
+ assert [model["id"] for model in result] == ["qwen-omni-turbo", "qwen3-omni-30b-a3b-instruct"]
+ assert all(model["model_type"] == "vlm3" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+
@pytest.mark.asyncio
async def test_get_models_rerank_success(self, mocker: MockFixture):
"""Test successful model retrieval for rerank models."""
diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py
index e93d8ba7b..58e514dbb 100644
--- a/test/backend/services/providers/test_tokenpony_provider.py
+++ b/test/backend/services/providers/test_tokenpony_provider.py
@@ -126,6 +126,16 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
"id": "qwen-vl-plus",
"object": "model",
"owned_by": "qwen"
+ },
+ {
+ "id": "qwen3.6-27b",
+ "object": "model",
+ "owned_by": "qwen"
+ },
+ {
+ "id": "qwen-vl-max",
+ "object": "model",
+ "owned_by": "qwen"
}
]
}
@@ -155,11 +165,121 @@ async def test_get_models_vlm_success(self, mocker: MockFixture):
result = await provider.get_models(provider_config)
+ assert [model["id"] for model in result] == ["qwen-vl-plus", "qwen3.6-27b", "qwen-vl-max"]
+ assert all(model["model_type"] == "vlm" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+
+ @pytest.mark.asyncio
+ async def test_get_models_vlm2_only_returns_image_generation_models(self, mocker: MockFixture):
+ """Image generation slot only returns image-generation multimodal models."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "data": [
+ {
+ "id": "qwen-vl-plus",
+ "object": "model",
+ "owned_by": "qwen"
+ },
+ {
+ "id": "flux-image-pro",
+ "object": "model",
+ "owned_by": "flux"
+ },
+ {
+ "id": "qwen-vl-max",
+ "object": "model",
+ "owned_by": "qwen"
+ },
+ {
+ "id": "qwen-plus",
+ "object": "model",
+ "owned_by": "qwen"
+ }
+ ]
+ }
+ mock_response.raise_for_status = MagicMock()
+
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_response
+
+ mock_cm = MagicMock()
+ mock_cm.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_cm.__aexit__ = AsyncMock(return_value=None)
+
+ mocker.patch(
+ "backend.services.providers.tokenpony_provider.httpx.AsyncClient",
+ return_value=mock_cm
+ )
+
+ provider = TokenPonyModelProvider()
+ provider_config = {
+ "model_type": "vlm2",
+ "api_key": "test-api-key"
+ }
+
+ result = await provider.get_models(provider_config)
+
assert len(result) == 1
- assert result[0]["id"] == "qwen-vl-plus"
- assert result[0]["model_type"] == "vlm"
+ assert result[0]["id"] == "flux-image-pro"
+ assert result[0]["model_type"] == "vlm2"
assert result[0]["model_tag"] == "chat"
+ @pytest.mark.asyncio
+ async def test_get_models_vlm3_only_returns_video_understanding_models(self, mocker: MockFixture):
+ """Video understanding slot excludes image generation and text-only models."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "data": [
+ {
+ "id": "flux-image-pro",
+ "object": "model",
+ "owned_by": "flux"
+ },
+ {
+ "id": "qwen-omni-video",
+ "object": "model",
+ "owned_by": "qwen"
+ },
+ {
+ "id": "qwen3-omni-30b-a3b-instruct",
+ "object": "model",
+ "owned_by": "qwen"
+ },
+ {
+ "id": "qwen-plus",
+ "object": "model",
+ "owned_by": "qwen"
+ }
+ ]
+ }
+ mock_response.raise_for_status = MagicMock()
+
+ mock_client = AsyncMock()
+ mock_client.get.return_value = mock_response
+
+ mock_cm = MagicMock()
+ mock_cm.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_cm.__aexit__ = AsyncMock(return_value=None)
+
+ mocker.patch(
+ "backend.services.providers.tokenpony_provider.httpx.AsyncClient",
+ return_value=mock_cm
+ )
+
+ provider = TokenPonyModelProvider()
+ provider_config = {
+ "model_type": "vlm3",
+ "api_key": "test-api-key"
+ }
+
+ result = await provider.get_models(provider_config)
+
+ assert [model["id"] for model in result] == ["qwen-omni-video", "qwen3-omni-30b-a3b-instruct"]
+ assert all(model["model_type"] == "vlm3" for model in result)
+ assert all(model["model_tag"] == "chat" for model in result)
+
@pytest.mark.asyncio
async def test_get_models_rerank_success(self, mocker: MockFixture):
"""Test successful model retrieval for rerank models."""
diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py
index 139ded1a2..559677527 100644
--- a/test/backend/services/test_model_health_service.py
+++ b/test/backend/services/test_model_health_service.py
@@ -24,6 +24,7 @@ def __getattr__(cls, key):
sys.modules['utils'] = MockModule()
sys.modules['utils.auth_utils'] = MockModule()
sys.modules['utils.config_utils'] = MockModule()
+sys.modules['utils.memory_utils'] = MockModule()
sys.modules['utils.model_name_utils'] = MockModule()
sys.modules['consts'] = MockModule()
consts_const_module = MockModule()
@@ -221,6 +222,51 @@ async def test_perform_connectivity_check_vlm():
mock_model_instance.check_connectivity.assert_called_once()
+@pytest.mark.asyncio
+async def test_perform_connectivity_check_dashscope_multimodal_uses_provider_catalog():
+ model_provider_service = types.ModuleType("services.model_provider_service")
+ model_provider_service.get_provider_models = mock.AsyncMock(return_value=[
+ {"id": "qwen-image-max", "model_type": "vlm2"},
+ ])
+
+ with mock.patch.dict(sys.modules, {"services.model_provider_service": model_provider_service}), \
+ mock.patch("backend.services.model_health_service.OpenAIVLModel") as mock_model:
+ result = await _perform_connectivity_check(
+ "qwen-image-max",
+ "vlm2",
+ "https://dashscope.aliyuncs.com/compatible-mode/v1/",
+ "test-key",
+ model_factory="dashscope",
+ )
+
+ assert result is True
+ model_provider_service.get_provider_models.assert_awaited_once_with({
+ "provider": "dashscope",
+ "model_type": "vlm2",
+ "api_key": "test-key",
+ })
+ mock_model.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_perform_connectivity_check_tokenpony_multimodal_catalog_error_returns_false():
+ model_provider_service = types.ModuleType("services.model_provider_service")
+ model_provider_service.get_provider_models = mock.AsyncMock(return_value=[
+ {"_error": "authentication_failed", "_message": "Invalid API key"},
+ ])
+
+ with mock.patch.dict(sys.modules, {"services.model_provider_service": model_provider_service}):
+ result = await _perform_connectivity_check(
+ "qwen-vl-plus",
+ "vlm3",
+ "https://api.tokenpony.cn/v1/",
+ "bad-key",
+ model_factory="tokenpony",
+ )
+
+ assert result is False
+
+
@pytest.mark.asyncio
async def test_perform_connectivity_check_stt():
# Setup
diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py
index 83a070fe0..be9c5e406 100644
--- a/test/backend/services/test_model_management_service.py
+++ b/test/backend/services/test_model_management_service.py
@@ -83,7 +83,22 @@ def get_value(status):
return status or _ModelConnectStatusEnum.NOT_DETECTED.value
+class _ToolValidateRequest:
+ def __init__(self, **kwargs):
+ self.__dict__.update(kwargs)
+
+
+class _ProcessParams:
+ def __init__(self, **kwargs):
+ self.__dict__.update(kwargs)
+
+ def model_dump(self, *args, **kwargs):
+ return dict(self.__dict__)
+
+
consts_model_mod.ModelConnectStatusEnum = _ModelConnectStatusEnum
+consts_model_mod.ToolValidateRequest = _ToolValidateRequest
+consts_model_mod.ProcessParams = _ProcessParams
sys.modules["consts.model"] = consts_model_mod
if "consts" not in sys.modules:
sys.modules["consts"] = types.ModuleType("consts")
@@ -93,6 +108,23 @@ def get_value(status):
consts_const_mod.LOCALHOST_IP = "127.0.0.1"
consts_const_mod.LOCALHOST_NAME = "localhost"
consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal"
+consts_const_mod.DATA_PROCESS_SERVICE = "http://data-process"
+consts_const_mod.FILE_PREVIEW_SIZE_LIMIT = 100 * 1024 * 1024
+consts_const_mod.MAX_CONCURRENT_UPLOADS = 5
+consts_const_mod.OFFICE_MIME_TYPES = []
+consts_const_mod.UPLOAD_FOLDER = "uploads"
+consts_const_mod.LOCAL_MCP_SERVER = "http://local-mcp"
+consts_const_mod.MCP_MANAGEMENT_API = "http://mcp-management"
+consts_const_mod.LIBREOFFICE_PROFILE_DIR = "libreoffice-profile"
+consts_const_mod.DEFAULT_TENANT_ID = "tenant_id"
+consts_const_mod.DEFAULT_USER_ID = "user_id"
+consts_const_mod.IS_SPEED_MODE = False
+consts_const_mod.SUPABASE_JWT_SECRET = "test-secret"
+consts_const_mod.SUPABASE_URL = "http://supabase"
+consts_const_mod.SUPABASE_KEY = "supabase-key"
+consts_const_mod.SERVICE_ROLE_KEY = "service-role-key"
+consts_const_mod.DEBUG_JWT_EXPIRE_SECONDS = 3600
+consts_const_mod.LANGUAGE = "zh"
# Fields required by utils.memory_utils and services.vectordatabase_service
consts_const_mod.MODEL_CONFIG_MAPPING = {
"llm": "LLM_ID", "embedding": "EMBEDDING_ID"}
@@ -193,9 +225,23 @@ def _sort_models_by_id(model_list):
utils_name_mod.sort_models_by_id = _sort_models_by_id
sys.modules["utils.model_name_utils"] = utils_name_mod
+# Stub utils.file_management_utils so file_management_service can be imported
+# by other tests in the same pytest process without pulling auth/database deps.
+utils_file_mgmt_mod = types.ModuleType("utils.file_management_utils")
+
+
+async def _save_upload_file(*args, **kwargs):
+ return None
+
+
+utils_file_mgmt_mod.save_upload_file = _save_upload_file
+sys.modules["utils.file_management_utils"] = utils_file_mgmt_mod
+
# Stub database.model_management_db to avoid importing heavy DB client
database_mod = types.ModuleType("database")
+database_mod.__path__ = []
db_mm_mod = types.ModuleType("database.model_management_db")
+db_attachment_mod = types.ModuleType("database.attachment_db")
def _noop(*args, **kwargs):
@@ -245,6 +291,22 @@ def _get_model_by_model_id(model_id: int, tenant_id: str):
db_mm_mod.update_model_record = _noop
sys.modules["database"] = database_mod
sys.modules["database.model_management_db"] = db_mm_mod
+for _attachment_func in [
+ "copy_file",
+ "delete_file",
+ "file_exists",
+ "get_content_type",
+ "get_file_range",
+ "get_file_size_from_minio",
+ "get_file_stream",
+ "get_file_stream_raw",
+ "get_file_url",
+ "list_files",
+ "upload_fileobj",
+]:
+ setattr(db_attachment_mod, _attachment_func, _noop)
+sys.modules["database.attachment_db"] = db_attachment_mod
+setattr(database_mod, "attachment_db", db_attachment_mod)
# Stub database.tenant_config_db required by utils.config_utils
db_tenant_cfg_mod = types.ModuleType("database.tenant_config_db")
@@ -286,10 +348,15 @@ def _update_config_by_tenant_config_id(*args, **kwargs):
services_vdb_mod = types.ModuleType("services.vectordatabase_service")
+class _ElasticSearchService:
+ pass
+
+
def _get_vector_db_core():
return object()
+services_vdb_mod.ElasticSearchService = _ElasticSearchService
services_vdb_mod.get_vector_db_core = _get_vector_db_core
sys.modules["services.vectordatabase_service"] = services_vdb_mod
@@ -340,7 +407,7 @@ def import_svc():
async def test_create_model_for_tenant_success_llm():
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value=None) as mock_get_by_display, \
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[]) as mock_get_by_display, \
mock.patch.object(svc, "create_model_record") as mock_create, \
mock.patch.object(svc, "split_repo_name", return_value=("huggingface", "llama")):
@@ -367,7 +434,7 @@ async def test_create_model_for_tenant_open_router_disables_ssl():
"""When base_url contains 'open/router' ssl_verify should be set to False and model_factory to 'modelengine'."""
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "create_model_record") as mock_create, \
mock.patch.object(svc, "split_repo_name", return_value=("modelengine", "m")):
@@ -394,7 +461,7 @@ async def test_create_model_for_tenant_open_router_disables_ssl():
async def test_create_model_for_tenant_conflict_raises():
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": "exists"}):
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[{"model_id": "exists", "model_type": "llm"}]):
user_id = "u1"
tenant_id = "t1"
model_data = {
@@ -415,7 +482,7 @@ async def test_create_model_for_tenant_display_name_conflict_valueerror():
svc = import_svc()
existing_model = {"model_id": 1, "display_name": "existing_name"}
- with mock.patch.object(svc, "get_model_by_display_name", return_value=existing_model):
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[existing_model]):
user_id = "u1"
tenant_id = "t1"
model_data = {
@@ -432,11 +499,58 @@ async def test_create_model_for_tenant_display_name_conflict_valueerror():
assert "existing_name" in str(exc.value)
+@pytest.mark.asyncio
+async def test_create_model_for_tenant_allows_same_display_name_across_multimodal_slots():
+ """Image understanding, image generation, and video understanding are separate slots."""
+ svc = import_svc()
+
+ existing_models = [
+ {"model_id": 1, "display_name": "Qwen3.6-27B", "model_type": "vlm"},
+ {"model_id": 2, "display_name": "Qwen3.6-27B", "model_type": "vlm3"},
+ ]
+
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models), \
+ mock.patch.object(svc, "create_model_record") as mock_create, \
+ mock.patch.object(svc, "split_repo_name", return_value=("Qwen", "Qwen3.6-27B")):
+
+ model_data = {
+ "model_name": "Qwen/Qwen3.6-27B",
+ "display_name": "Qwen3.6-27B",
+ "base_url": "https://api.example.com/v1",
+ "model_type": "vlm2",
+ }
+
+ await svc.create_model_for_tenant("u1", "t1", model_data)
+
+ mock_create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_create_model_for_tenant_blocks_duplicate_within_same_multimodal_slot():
+ svc = import_svc()
+
+ with mock.patch.object(
+ svc,
+ "get_models_by_display_name",
+ return_value=[{"model_id": 1, "display_name": "Qwen3.6-27B", "model_type": "vlm"}],
+ ):
+ model_data = {
+ "model_name": "Qwen/Qwen3.6-27B",
+ "display_name": "Qwen3.6-27B",
+ "base_url": "https://api.example.com/v1",
+ "model_type": "vlm",
+ }
+
+ with pytest.raises(Exception) as exc:
+ await svc.create_model_for_tenant("u1", "t1", model_data)
+ assert "already in use" in str(exc.value)
+
+
@pytest.mark.asyncio
async def test_create_model_for_tenant_multi_embedding_creates_two_records():
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "create_model_record") as mock_create, \
mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")):
@@ -458,7 +572,7 @@ async def test_create_model_for_tenant_multi_embedding_creates_two_records():
async def test_create_model_for_tenant_embedding_sets_dimension():
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=1536)) as mock_dim, \
mock.patch.object(svc, "create_model_record") as mock_create, \
mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-ada-002")):
@@ -484,7 +598,7 @@ async def test_create_model_for_tenant_embedding_sets_default_chunk_batch():
"""chunk_batch defaults to 10 when not provided for embedding models."""
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \
mock.patch.object(svc, "create_model_record") as mock_create, \
mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-3-small")):
@@ -513,7 +627,7 @@ async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch(
"""chunk_batch defaults to 10 when not provided for multi_embedding models (covers line 79)."""
svc = import_svc()
- with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \
mock.patch.object(svc, "create_model_record") as mock_create, \
mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")):
@@ -591,7 +705,7 @@ async def test_batch_create_models_for_tenant_dashscope_provider():
mock.patch.object(svc, "delete_model_record"), \
mock.patch.object(svc, "split_repo_name", return_value=("qwen", "qwen-turbo")), \
mock.patch.object(svc, "add_repo_to_name", return_value="qwen/qwen-turbo"), \
- mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \
mock.patch.object(svc, "create_model_record", return_value=True):
@@ -617,7 +731,7 @@ async def test_batch_create_models_for_tenant_tokenpony_provider():
mock.patch.object(svc, "delete_model_record"), \
mock.patch.object(svc, "split_repo_name", return_value=("gpt", "gpt-4o")), \
mock.patch.object(svc, "add_repo_to_name", return_value="gpt/gpt-4o"), \
- mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 2})), \
mock.patch.object(svc, "create_model_record", return_value=True):
@@ -650,7 +764,7 @@ async def test_batch_create_models_for_tenant_other_provider():
mock.patch.object(svc, "delete_model_record"), \
mock.patch.object(svc, "split_repo_name", return_value=("openai", "gpt-4")), \
mock.patch.object(svc, "add_repo_to_name", return_value="openai/gpt-4"), \
- mock.patch.object(svc, "get_model_by_display_name", return_value=None), \
+ mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \
mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \
mock.patch.object(svc, "create_model_record", return_value=True):
diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py
index 94401b61d..7369ddfb2 100644
--- a/test/sdk/core/tools/test_analyze_audio_video_tool.py
+++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py
@@ -44,9 +44,9 @@ def _fake_get_prompt(template_type, language=None, **_):
storage_client=mock_storage_client,
)
- result = tool._forward_impl([b"ID3audio-bytes"], "what happened?")
+ result = tool._forward_impl(audio_url=b"ID3audio-bytes", query="what happened?")
- assert result == ["audio result"]
+ assert result == "audio result"
assert calls == [("analyze_audio", "en")]
mock_vlm_model.analyze_audio.assert_called_once()
call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs
@@ -55,6 +55,30 @@ def _fake_get_prompt(template_type, language=None, **_):
observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...")
+def test_analyze_audio_schema_uses_single_url():
+ assert "audio_url" in AnalyzeAudioTool.inputs
+ assert "audio_urls_list" not in AnalyzeAudioTool.inputs
+ assert AnalyzeAudioTool.output_type == "string"
+
+
+def test_analyze_audio_accepts_legacy_url_list(observer_en, mock_vlm_model, mock_storage_client, monkeypatch):
+ monkeypatch.setattr(
+ analyze_audio_tool,
+ "get_prompt_template",
+ lambda template_type, language=None, **_: {"system_prompt": "Analyze audio for {{ query }}"},
+ )
+ mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result")
+ tool = AnalyzeAudioTool(
+ observer=observer_en,
+ vlm_model=mock_vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ result = tool._forward_impl(audio_urls_list=[b"ID3audio-bytes"], query="what happened?")
+
+ assert result == "audio result"
+
+
def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client):
vlm_model = SimpleNamespace(
model_id="Qwen/Qwen3-VL-32B-Instruct",
@@ -67,7 +91,7 @@ def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_stor
)
with pytest.raises(ValueError) as exc_info:
- tool._forward_impl([b"ID3audio-bytes"], "what happened?")
+ tool._forward_impl(audio_url=b"ID3audio-bytes", query="what happened?")
assert "Please choose a Qwen3-Omni model" in str(exc_info.value)
@@ -87,9 +111,9 @@ def _fake_get_prompt(template_type, language=None, **_):
storage_client=mock_storage_client,
)
- result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?")
+ result = tool._forward_impl(video_url=b"\x00\x00\x00\x18ftypmp42video-bytes", query="what happened?")
- assert result == ["video result"]
+ assert result == "video result"
assert calls == [("analyze_video", "en")]
mock_vlm_model.analyze_video.assert_called_once()
call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs
@@ -98,6 +122,30 @@ def _fake_get_prompt(template_type, language=None, **_):
observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...")
+def test_analyze_video_schema_uses_single_url():
+ assert "video_url" in AnalyzeVideoTool.inputs
+ assert "video_urls_list" not in AnalyzeVideoTool.inputs
+ assert AnalyzeVideoTool.output_type == "string"
+
+
+def test_analyze_video_accepts_legacy_url_list(observer_en, mock_vlm_model, mock_storage_client, monkeypatch):
+ monkeypatch.setattr(
+ analyze_video_tool,
+ "get_prompt_template",
+ lambda template_type, language=None, **_: {"system_prompt": "Analyze video for {{ query }}"},
+ )
+ mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result")
+ tool = AnalyzeVideoTool(
+ observer=observer_en,
+ vlm_model=mock_vlm_model,
+ storage_client=mock_storage_client,
+ )
+
+ result = tool._forward_impl(video_urls_list=[b"\x00\x00\x00\x18ftypmp42video-bytes"], query="what happened?")
+
+ assert result == "video result"
+
+
@pytest.mark.parametrize(
"tool_class,input_name,error_text",
[
From 4aa9fa98af41d1270c8dfacacef30e611906f674 Mon Sep 17 00:00:00 2001
From: 827dls <1670704430@qq.com>
Date: Thu, 28 May 2026 00:56:01 +0800
Subject: [PATCH 7/7] fix test 2
---
.../test_tool_configuration_service.py | 37 +++++++++++++++++++
1 file changed, 37 insertions(+)
diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py
index ed6879c6e..29d2f75f6 100644
--- a/test/backend/services/test_tool_configuration_service.py
+++ b/test/backend/services/test_tool_configuration_service.py
@@ -3611,6 +3611,7 @@ def test_get_llm_model_success(self, mock_tenant_config, mock_get_model_name, mo
api_key="test_api_key",
max_context_tokens=4096,
ssl_verify=True,
+ timeout_seconds=None,
)
@patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"})
@@ -3650,6 +3651,42 @@ def test_get_llm_model_with_missing_config_values(self, mock_tenant_config, mock
call_kwargs = mock_openai_model.call_args[1]
assert call_kwargs["api_key"] is None
assert call_kwargs["max_context_tokens"] is None
+ assert call_kwargs["timeout_seconds"] is None
+
+ @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"})
+ @patch('backend.services.file_management_service.MessageObserver')
+ @patch('backend.services.file_management_service.OpenAILongContextModel')
+ @patch('backend.services.file_management_service.get_model_name_from_config')
+ @patch('backend.services.file_management_service.tenant_config_manager')
+ def test_get_llm_model_with_timeout_seconds(self, mock_tenant_config, mock_get_model_name, mock_openai_model, mock_message_observer):
+ """Test get_llm_model passes configured timeout_seconds."""
+ from backend.services.file_management_service import get_llm_model
+
+ mock_config = {
+ "base_url": "http://api.example.com",
+ "api_key": "test_api_key",
+ "max_tokens": 4096,
+ "timeout_seconds": 30,
+ }
+ mock_tenant_config.get_model_config.return_value = mock_config
+ mock_get_model_name.return_value = "gpt-4"
+ mock_observer_instance = Mock()
+ mock_message_observer.return_value = mock_observer_instance
+ mock_model_instance = Mock()
+ mock_openai_model.return_value = mock_model_instance
+
+ result = get_llm_model("tenant123")
+
+ assert result == mock_model_instance
+ mock_openai_model.assert_called_once_with(
+ observer=mock_observer_instance,
+ model_id="gpt-4",
+ api_base="http://api.example.com",
+ api_key="test_api_key",
+ max_context_tokens=4096,
+ ssl_verify=True,
+ timeout_seconds=30,
+ )
@patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"})
@patch('backend.services.file_management_service.MessageObserver')