diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 54d7b3bc8..109f92ad0 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -15,6 +15,11 @@ logger = logging.getLogger("model_health_service") +DASHSCOPE_MODEL_FACTORY = "dashscope" +TOKENPONY_MODEL_FACTORY = "tokenpony" +PROVIDER_CATALOG_HEALTHCHECK_FACTORIES = {DASHSCOPE_MODEL_FACTORY, TOKENPONY_MODEL_FACTORY} +PROVIDER_CATALOG_HEALTHCHECK_TYPES = {"vlm", "vlm2", "vlm3"} + def _mask_secret(value: Optional[str]) -> str: """Mask a secret value, showing only first and last 4 characters.""" @@ -64,6 +69,31 @@ async def _embedding_dimension_check( raise ValueError(f"Unsupported model type: {model_type}") +async def _provider_catalog_connectivity_check( + model_name: str, + model_type: str, + model_api_key: str, + model_factory: Optional[str], +) -> bool: + """Validate provider-managed multimodal models through their model catalog.""" + provider = (model_factory or "").lower() + if provider not in PROVIDER_CATALOG_HEALTHCHECK_FACTORIES: + return False + + from services.model_provider_service import get_provider_models + + model_list = await get_provider_models({ + "provider": provider, + "model_type": model_type, + "api_key": model_api_key, + }) + if not model_list or any(model.get("_error") for model in model_list): + return False + + expected_model_id = model_name.lower() + return any(str(model.get("id", "")).lower() == expected_model_id for model in model_list) + + async def _perform_connectivity_check( model_name: str, model_type: str, @@ -135,6 +165,18 @@ async def _perform_connectivity_check( ) connectivity = await rerank_model.connectivity_check() elif model_type in ("vlm", "vlm2", "vlm3"): + if ( + model_type in PROVIDER_CATALOG_HEALTHCHECK_TYPES + and (model_factory or "").lower() in PROVIDER_CATALOG_HEALTHCHECK_FACTORIES + ): + connectivity = await _provider_catalog_connectivity_check( + model_name=model_name, + model_type=model_type, + model_api_key=model_api_key, + model_factory=model_factory, + ) + return connectivity + observer = MessageObserver() set_monitoring_operation("connectivity_check", display_name=display_name) diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 8f6d191fd..9f032728a 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -8,7 +8,6 @@ from database.model_management_db import ( create_model_record, delete_model_record, - get_model_by_display_name, get_model_by_name_factory, get_models_by_display_name, get_model_records, @@ -32,6 +31,23 @@ logger = logging.getLogger("model_management_service") +INDEPENDENT_MULTIMODAL_MODEL_TYPES = {"vlm", "vlm2", "vlm3"} + + +def _has_display_name_conflict(existing_models: List[Dict[str, Any]], model_type: Optional[str]) -> bool: + """Allow the three multimodal slots to share display names across slots.""" + if not existing_models: + return False + + if model_type in INDEPENDENT_MULTIMODAL_MODEL_TYPES: + return any( + existing.get("model_type") == model_type + or existing.get("model_type") not in INDEPENDENT_MULTIMODAL_MODEL_TYPES + for existing in existing_models + ) + + return True + async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict[str, Any]): """Create a single model record for the given tenant. @@ -77,9 +93,9 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict # Check display name conflict scoped by tenant if model_data.get("display_name"): - existing_model_by_display = get_model_by_display_name( + existing_models_by_display = get_models_by_display_name( model_data["display_name"], tenant_id) - if existing_model_by_display: + if _has_display_name_conflict(existing_models_by_display, model_data.get("model_type")): logging.error( f"Name {model_data['display_name']} is already in use, please choose another display name") raise ValueError( diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 69096fb15..497dcfe99 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -6,6 +6,75 @@ from services.providers.base import AbstractModelProvider, _classify_provider_error +DASHSCOPE_IMAGE_GENERATION_KEYWORDS = ( + "image", + "wanx", + "aitryon", + "tryon", + "flux", + "stable-diffusion", + "sdxl", +) +DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS = ( + "qwen-vl", + "qwen2-vl", + "qwen2.5-vl", + "qwen3-vl", + "qwen3.5-vl", + "qwen3.6-vl", + "-vl", + "vl-", + "vision", + "visual", + "ocr", + "qwen3.6", + "qwen-3.6", +) +DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video-understanding", "video-ocr") + + +def _modality_set(value) -> set: + if not value: + return set() + if isinstance(value, str): + return {value.lower()} + return {str(item).lower() for item in value} + + +def _has_keyword(text: str, keywords: tuple) -> bool: + return any(keyword in text for keyword in keywords) + + +def _is_dashscope_explicit_image_understanding_model(model_id: str) -> bool: + return _has_keyword(model_id, DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS) + + +def _is_dashscope_image_generation_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool: + if _is_dashscope_explicit_image_understanding_model(model_id): + return False + return "image" in res_mods or _has_keyword(model_id, DASHSCOPE_IMAGE_GENERATION_KEYWORDS) + + +def _is_dashscope_video_understanding_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool: + searchable_text = f"{model_id} {desc.lower()}" + if "video" in req_mods and "text" in res_mods: + return True + return _has_keyword(searchable_text, DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS) + + +def _is_dashscope_image_understanding_model(model_id: str, desc: str, req_mods: set, res_mods: set) -> bool: + searchable_text = f"{model_id} {desc.lower()}" + if _is_dashscope_image_generation_model(model_id, desc, req_mods, res_mods): + return False + if _is_dashscope_video_understanding_model(model_id, desc, req_mods, res_mods): + return False + if ("image" in req_mods or "video" in req_mods) and "text" in res_mods: + return True + return _is_dashscope_explicit_image_understanding_model(model_id) or _has_keyword( + searchable_text, DASHSCOPE_IMAGE_UNDERSTANDING_KEYWORDS + ) + + class DashScopeModelProvider(AbstractModelProvider): """Concrete implementation for DashScope (Aliyun) provider.""" @@ -57,6 +126,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models = { "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" + "vlm2": [], # Maps to image generation models + "vlm3": [], # Maps to video understanding models "embedding": [], # Maps to "embedding" / "multi_embedding" "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" @@ -71,6 +142,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: metadata = model_obj.get('inference_metadata') or {} req_mod = metadata.get('request_modality', []) res_mod = metadata.get('response_modality', []) + req_mods = _modality_set(req_mod) + res_mods = _modality_set(res_mod) model_obj.setdefault("object", model_obj.get("object", "model")) model_obj.setdefault("owned_by", model_obj.get("owned_by", "dashscope")) cleaned_model = { @@ -107,8 +180,17 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: continue # 5. VLM - vision_mods = {'Image', 'Video'} - if (set(req_mod) & vision_mods) or (set(res_mod) & vision_mods) or '视觉' in desc: + if _is_dashscope_video_understanding_model(m_id, desc, req_mods, res_mods): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm3"}) + categorized_models['vlm3'].append(cleaned_model) + continue + + if _is_dashscope_image_generation_model(m_id, desc, req_mods, res_mods): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm2"}) + categorized_models['vlm2'].append(cleaned_model) + continue + + if _is_dashscope_image_understanding_model(m_id, desc, req_mods, res_mods): cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) categorized_models['vlm'].append(cleaned_model) continue @@ -124,7 +206,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: elif target_model_type in ("embedding", "multi_embedding"): return categorized_models["embedding"] elif target_model_type in categorized_models: - return categorized_models[target_model_type] + return [ + {**model, "model_type": target_model_type} + for model in categorized_models[target_model_type] + ] else: return [] except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index ab4446c1b..be2bb9c71 100644 --- a/backend/services/providers/tokenpony_provider.py +++ b/backend/services/providers/tokenpony_provider.py @@ -9,6 +9,64 @@ from services.providers.base import AbstractModelProvider, _classify_provider_error +TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS = ( + "qwen-vl", + "qwen2-vl", + "qwen2.5-vl", + "qwen3-vl", + "qwen3.5-vl", + "qwen3.6-vl", + "-vl", + "vl-", + "vision", + "visual", + "ocr", + "gpt-4o", + "qwen3.6", + "qwen-3.6", +) +TOKENPONY_IMAGE_GENERATION_KEYWORDS = ( + "image", + "dall", + "flux", + "stable-diffusion", + "sdxl", + "midjourney", + "wanx", + "kolors", + "seedream", + "ideogram", + "recraft", +) +TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video") + + +def _has_keyword(text: str, keywords: tuple) -> bool: + return any(keyword in text for keyword in keywords) + + +def _is_tokenpony_explicit_image_understanding_model(model_id: str) -> bool: + return _has_keyword(model_id, TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS) + + +def _is_tokenpony_image_generation_model(model_id: str) -> bool: + if _is_tokenpony_explicit_image_understanding_model(model_id): + return False + return _has_keyword(model_id, TOKENPONY_IMAGE_GENERATION_KEYWORDS) + + +def _is_tokenpony_video_understanding_model(model_id: str) -> bool: + return _has_keyword(model_id, TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS) + + +def _is_tokenpony_image_understanding_model(model_id: str) -> bool: + if _is_tokenpony_image_generation_model(model_id): + return False + if _is_tokenpony_video_understanding_model(model_id): + return False + return _is_tokenpony_explicit_image_understanding_model(model_id) + + class TokenPonyModelProvider(AbstractModelProvider): """Concrete implementation for TokenPony provider.""" @@ -46,6 +104,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: categorized_models = { "chat": [], # Maps to "llm" "vlm": [], # Maps to "vlm" + "vlm2": [], # Maps to image generation models + "vlm3": [], # Maps to video understanding models "embedding": [], # Maps to "embedding" / "multi_embedding" "rerank": [], # Maps to "rerank" "tts": [], # Maps to "tts" @@ -86,9 +146,14 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: cleaned_model.update({"model_tag": "tts", "model_type": "tts"}) categorized_models['tts'].append(cleaned_model) - # 5. VLM (Vision Language Model / Image & Video Generation) - - elif any(keyword in m_id for keyword in ['-vl', 'vl-', 'ocr', 'vision']): + # 5. Multimodal models + elif _is_tokenpony_video_understanding_model(m_id): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm3"}) + categorized_models['vlm3'].append(cleaned_model) + elif _is_tokenpony_image_generation_model(m_id): + cleaned_model.update({"model_tag": "chat", "model_type": "vlm2"}) + categorized_models['vlm2'].append(cleaned_model) + elif _is_tokenpony_image_understanding_model(m_id): cleaned_model.update({"model_tag": "chat", "model_type": "vlm"}) categorized_models['vlm'].append(cleaned_model) @@ -104,7 +169,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: elif target_model_type in ("embedding", "multi_embedding"): return categorized_models["embedding"] elif target_model_type in categorized_models: - return categorized_models[target_model_type] + return [ + {**model, "model_type": target_model_type} + for model in categorized_models[target_model_type] + ] else: return [] diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 80df3728c..993795c98 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -113,11 +113,10 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig(); - const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable; const { isImageUnderstandingAvailable, isVideoUnderstandingAvailable, + isEmbeddingAvailable, } = useConfig(); // Prefetch knowledge bases for KB tools @@ -383,10 +382,7 @@ export default function ToolManagement({ isImageUnderstandingAvailable, isVideoUnderstandingAvailable ); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( - tool.name, - isEmbeddingOrMultiAvailable - ); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -495,10 +491,7 @@ export default function ToolManagement({ isImageUnderstandingAvailable, isVideoUnderstandingAvailable ); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( - tool.name, - isEmbeddingOrMultiAvailable - ); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/sdk/nexent/core/tools/analyze_audio_tool.py b/sdk/nexent/core/tools/analyze_audio_tool.py index c7509a6c2..1e5439443 100644 --- a/sdk/nexent/core/tools/analyze_audio_tool.py +++ b/sdk/nexent/core/tools/analyze_audio_tool.py @@ -7,7 +7,7 @@ import logging from io import BytesIO -from typing import List +from typing import List, Optional from jinja2 import StrictUndefined, Template from pydantic import Field @@ -28,28 +28,29 @@ class AnalyzeAudioTool(Tool): """Tool for understanding and analyzing audio using the video understanding model.""" name = "analyze_audio" + skip_forward_signature_validation = True description = ( "This tool uses the configured video understanding model to understand audio based on your query and then returns an audio analysis result.\n" - "It is used to understand and analyze multiple audio files, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "It is used to understand and analyze one audio file, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " "HTTP, and HTTPS URLs.\n" "Use this tool when you want to retrieve information contained in audio and provide the audio URL and your query." ) description_zh = ( - "使用视频理解模型,根据你的提示词来理解音频,并返回音频分析结果。" - "可用于理解和分析多个音频文件,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + "使用视频理解模型,根据你的问题理解音频,并返回音频分析结果。" + "可用于理解和分析一个音频文件,支持 S3 URL(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" ) inputs = { - "audio_urls_list": { - "type": "array", - "description": "List of audio URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", - "description_zh": "列表形式输入音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + "audio_url": { + "type": "string", + "description": "Audio URL (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "音频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。", }, "query": { "type": "string", "description": "User's question to guide the audio analysis", - "description_zh": "用户用于指导音频分析的问题" - } + "description_zh": "用户的问题,用于指导音频分析", + }, } init_param_descriptions = { @@ -58,9 +59,9 @@ class AnalyzeAudioTool(Tool): "storage_client": {"description": "Storage client for downloading files"}, "validate_url_access": { "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" - } + }, } - output_type = "array" + output_type = "string" category = ToolCategory.MULTIMODAL.value tool_sign = ToolSign.MULTIMODAL_OPERATION.value @@ -94,10 +95,10 @@ def __init__( validate_callback = validate_url_access self.mm = LoadSaveObjectManager( storage_client=self.storage_client, - validate_url_access=validate_callback + validate_url_access=validate_callback, ) self.forward = self.mm.load_object( - input_names=["audio_urls_list"])(self._forward_impl) + input_names=["audio_url", "audio_urls_list"])(self._forward_impl) self.running_prompt_zh = "正在分析音频..." self.running_prompt_en = "Analyzing audio..." @@ -114,10 +115,14 @@ def _validate_audio_capable_model(self) -> None: "Please choose a Qwen3-Omni model for analyze_audio." ) - def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: - """Analyze audio files and return one result per audio input.""" + def _forward_impl( + self, + audio_url: Optional[bytes] = None, + query: str = "", + audio_urls_list: Optional[List[bytes]] = None) -> str: + """Analyze an audio file and return the result as a string.""" if self.vlm_model is None: - error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试。" error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." error_msg = error_msg_zh if self._is_chinese else error_msg_en logger.error(error_msg) @@ -128,12 +133,17 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en self.observer.add_message("", ProcessType.TOOL, running_prompt) - if audio_urls_list is None: - raise ValueError("audio_urls cannot be None") - if not isinstance(audio_urls_list, list): - raise ValueError("audio_urls must be a list of bytes") - if not audio_urls_list: - raise ValueError("audio_urls must contain at least one audio file") + if audio_url is not None: + audio_items = [audio_url] + else: + audio_items = audio_urls_list + + if audio_items is None: + raise ValueError("audio_url cannot be None") + if not isinstance(audio_items, list): + raise ValueError("audio_url must be bytes or audio_urls_list must be a list of bytes") + if not audio_items: + raise ValueError("audio_url must contain an audio file") language = self.observer.lang if self.observer else "en" prompts = get_prompt_template( @@ -143,7 +153,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: try: analysis_results: List[str] = [] - for index, audio_bytes in enumerate(audio_urls_list, start=1): + for index, audio_bytes in enumerate(audio_items, start=1): logger.info(f"Analyzing audio #{index}, query: {query}") content_type = detect_content_type_from_bytes(audio_bytes) if not content_type.startswith("audio/"): @@ -153,7 +163,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: response = self.vlm_model.analyze_audio( audio_input=audio_stream, system_prompt=system_prompt, - content_type=content_type + content_type=content_type, ) except Exception as e: error_msg_zh = f"音频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" @@ -163,7 +173,7 @@ def _forward_impl(self, audio_urls_list: List[bytes], query: str) -> List[str]: analysis_results.append(response.content) - return analysis_results + return "\n\n".join(analysis_results) except Exception as e: logger.error(f"Error analyzing audio: {str(e)}", exc_info=True) raise Exception(f"Error analyzing audio: {str(e)}") diff --git a/sdk/nexent/core/tools/analyze_video_tool.py b/sdk/nexent/core/tools/analyze_video_tool.py index 3dc033551..e7bf84549 100644 --- a/sdk/nexent/core/tools/analyze_video_tool.py +++ b/sdk/nexent/core/tools/analyze_video_tool.py @@ -1,13 +1,13 @@ """ Analyze Video Tool -Analyze videos using the configured video understanding model. -Supports videos from S3, HTTP, and HTTPS URLs. +Analyze video using the configured video understanding model. +Supports video from S3, HTTP, and HTTPS URLs. """ import logging from io import BytesIO -from typing import List +from typing import List, Optional from jinja2 import StrictUndefined, Template from pydantic import Field @@ -25,31 +25,32 @@ class AnalyzeVideoTool(Tool): - """Tool for understanding and analyzing videos using the video understanding model.""" + """Tool for understanding and analyzing video using the video understanding model.""" name = "analyze_video" + skip_forward_signature_validation = True description = ( - "This tool uses the configured video understanding model to understand videos based on your query and then returns a video analysis result.\n" - "It is used to understand and analyze multiple videos, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " + "This tool uses the configured video understanding model to understand video based on your query and then returns a video analysis result.\n" + "It is used to understand and analyze one video, with sources supporting S3 URLs (s3://bucket/key or /bucket/key), " "HTTP, and HTTPS URLs.\n" - "Use this tool when you want to retrieve information contained in a video and provide the video's URL and your query." + "Use this tool when you want to retrieve information contained in a video and provide the video URL and your query." ) description_zh = ( - "使用视频理解模型,根据你的提示词来理解视频,并返回视频分析结果。" - "可用于理解和分析多个视频,支持 S3 URLs(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" + "使用视频理解模型,根据你的问题理解视频,并返回视频分析结果。" + "可用于理解和分析一个视频,支持 S3 URL(s3://bucket/key 或 /bucket/key)、HTTP 和 HTTPS URL。" ) inputs = { - "video_urls_list": { - "type": "array", - "description": "List of video URLs (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", - "description_zh": "列表形式输入视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。" + "video_url": { + "type": "string", + "description": "Video URL (S3, HTTP, or HTTPS). Supports s3://bucket/key, /bucket/key, http://, and https:// URLs.", + "description_zh": "视频 URL(S3、HTTP 或 HTTPS)。支持 s3://bucket/key、/bucket/key、http:// 和 https:// URL。", }, "query": { "type": "string", "description": "User's question to guide the video analysis", - "description_zh": "用户用于指导视频分析的问题" - } + "description_zh": "用户的问题,用于指导视频分析", + }, } init_param_descriptions = { @@ -58,9 +59,9 @@ class AnalyzeVideoTool(Tool): "storage_client": {"description": "Storage client for downloading files"}, "validate_url_access": { "description": "Callback function to validate URL access permissions (passed to LoadSaveObjectManager)" - } + }, } - output_type = "array" + output_type = "string" category = ToolCategory.MULTIMODAL.value tool_sign = ToolSign.MULTIMODAL_OPERATION.value @@ -94,18 +95,22 @@ def __init__( validate_callback = validate_url_access self.mm = LoadSaveObjectManager( storage_client=self.storage_client, - validate_url_access=validate_callback + validate_url_access=validate_callback, ) self.forward = self.mm.load_object( - input_names=["video_urls_list"])(self._forward_impl) + input_names=["video_url", "video_urls_list"])(self._forward_impl) self.running_prompt_zh = "正在分析视频..." self.running_prompt_en = "Analyzing video..." - def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: - """Analyze videos and return one result per video input.""" + def _forward_impl( + self, + video_url: Optional[bytes] = None, + query: str = "", + video_urls_list: Optional[List[bytes]] = None) -> str: + """Analyze a video and return the result as a string.""" if self.vlm_model is None: - error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试" + error_msg_zh = "视频理解模型未配置,请联系管理员配置视频理解模型后重试。" error_msg_en = "Video understanding model is not configured. Please contact your administrator to configure the video understanding model and try again." error_msg = error_msg_zh if self._is_chinese else error_msg_en logger.error(error_msg) @@ -115,12 +120,17 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: running_prompt = self.running_prompt_zh if self._is_chinese else self.running_prompt_en self.observer.add_message("", ProcessType.TOOL, running_prompt) - if video_urls_list is None: - raise ValueError("video_urls cannot be None") - if not isinstance(video_urls_list, list): - raise ValueError("video_urls must be a list of bytes") - if not video_urls_list: - raise ValueError("video_urls must contain at least one video") + if video_url is not None: + video_items = [video_url] + else: + video_items = video_urls_list + + if video_items is None: + raise ValueError("video_url cannot be None") + if not isinstance(video_items, list): + raise ValueError("video_url must be bytes or video_urls_list must be a list of bytes") + if not video_items: + raise ValueError("video_url must contain a video") language = self.observer.lang if self.observer else "en" prompts = get_prompt_template( @@ -130,7 +140,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: try: analysis_results: List[str] = [] - for index, video_bytes in enumerate(video_urls_list, start=1): + for index, video_bytes in enumerate(video_items, start=1): logger.info(f"Analyzing video #{index}, query: {query}") content_type = detect_content_type_from_bytes(video_bytes) if not content_type.startswith("video/"): @@ -140,7 +150,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: response = self.vlm_model.analyze_video( video_input=video_stream, system_prompt=system_prompt, - content_type=content_type + content_type=content_type, ) except Exception as e: error_msg_zh = f"视频{index}分析失败: {str(e)}。请检查视频理解模型配置是否正确。" @@ -150,7 +160,7 @@ def _forward_impl(self, video_urls_list: List[bytes], query: str) -> List[str]: analysis_results.append(response.content) - return analysis_results + return "\n\n".join(analysis_results) except Exception as e: logger.error(f"Error analyzing video: {str(e)}", exc_info=True) raise Exception(f"Error analyzing video: {str(e)}") diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 30229677a..5c6267040 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -141,11 +141,27 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): "models": [ { "model": "qwen-vl-plus", - "description": "Vision language model", + "description": "Vision language model for image understanding", "inference_metadata": { "request_modality": ["Image", "Text"], "response_modality": ["Text"] } + }, + { + "model": "qwen3.6-27b", + "description": "Qwen 3.6 multimodal model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-vl-max", + "description": "Qwen VL max model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text", "Image"] + } } ] } @@ -167,11 +183,129 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) + assert [model["id"] for model in result] == ["qwen-vl-plus", "qwen3.6-27b", "qwen-vl-max"] + assert all(model["model_type"] == "vlm" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + + @pytest.mark.asyncio + async def test_get_models_vlm2_only_returns_image_generation_models(self, mocker: MockFixture): + """Image generation slot only returns image-generation multimodal models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-vl-plus", + "description": "Vision language model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-image-max", + "description": "Image generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Image"] + } + }, + { + "model": "qwen-vl-max", + "description": "Qwen VL max model", + "inference_metadata": { + "request_modality": ["Image", "Text"], + "response_modality": ["Text", "Image"] + } + }, + { + "model": "qwen-plus", + "description": "Text generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + self._setup_mock_client(mocker, mock_response) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "vlm2", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + assert len(result) == 1 - assert result[0]["id"] == "qwen-vl-plus" - assert result[0]["model_type"] == "vlm" + assert result[0]["id"] == "qwen-image-max" + assert result[0]["model_type"] == "vlm2" assert result[0]["model_tag"] == "chat" + @pytest.mark.asyncio + async def test_get_models_vlm3_only_returns_video_understanding_models(self, mocker: MockFixture): + """Video understanding slot excludes image generation and text-only models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-image-max", + "description": "Image generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Image"] + } + }, + { + "model": "qwen-omni-turbo", + "description": "Video understanding model", + "inference_metadata": { + "request_modality": ["Video", "Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen3-omni-30b-a3b-instruct", + "description": "Omni multimodal model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + }, + { + "model": "qwen-plus", + "description": "Text generation model", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"] + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + self._setup_mock_client(mocker, mock_response) + + provider = DashScopeModelProvider() + provider_config = { + "model_type": "vlm3", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert [model["id"] for model in result] == ["qwen-omni-turbo", "qwen3-omni-30b-a3b-instruct"] + assert all(model["model_type"] == "vlm3" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + @pytest.mark.asyncio async def test_get_models_rerank_success(self, mocker: MockFixture): """Test successful model retrieval for rerank models.""" diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py index e93d8ba7b..58e514dbb 100644 --- a/test/backend/services/providers/test_tokenpony_provider.py +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -126,6 +126,16 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): "id": "qwen-vl-plus", "object": "model", "owned_by": "qwen" + }, + { + "id": "qwen3.6-27b", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen-vl-max", + "object": "model", + "owned_by": "qwen" } ] } @@ -155,11 +165,121 @@ async def test_get_models_vlm_success(self, mocker: MockFixture): result = await provider.get_models(provider_config) + assert [model["id"] for model in result] == ["qwen-vl-plus", "qwen3.6-27b", "qwen-vl-max"] + assert all(model["model_type"] == "vlm" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + + @pytest.mark.asyncio + async def test_get_models_vlm2_only_returns_image_generation_models(self, mocker: MockFixture): + """Image generation slot only returns image-generation multimodal models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "qwen-vl-plus", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "flux-image-pro", + "object": "model", + "owned_by": "flux" + }, + { + "id": "qwen-vl-max", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen-plus", + "object": "model", + "owned_by": "qwen" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm2", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + assert len(result) == 1 - assert result[0]["id"] == "qwen-vl-plus" - assert result[0]["model_type"] == "vlm" + assert result[0]["id"] == "flux-image-pro" + assert result[0]["model_type"] == "vlm2" assert result[0]["model_tag"] == "chat" + @pytest.mark.asyncio + async def test_get_models_vlm3_only_returns_video_understanding_models(self, mocker: MockFixture): + """Video understanding slot excludes image generation and text-only models.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "flux-image-pro", + "object": "model", + "owned_by": "flux" + }, + { + "id": "qwen-omni-video", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen3-omni-30b-a3b-instruct", + "object": "model", + "owned_by": "qwen" + }, + { + "id": "qwen-plus", + "object": "model", + "owned_by": "qwen" + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + + provider = TokenPonyModelProvider() + provider_config = { + "model_type": "vlm3", + "api_key": "test-api-key" + } + + result = await provider.get_models(provider_config) + + assert [model["id"] for model in result] == ["qwen-omni-video", "qwen3-omni-30b-a3b-instruct"] + assert all(model["model_type"] == "vlm3" for model in result) + assert all(model["model_tag"] == "chat" for model in result) + @pytest.mark.asyncio async def test_get_models_rerank_success(self, mocker: MockFixture): """Test successful model retrieval for rerank models.""" diff --git a/test/backend/services/test_config_sync_service.py b/test/backend/services/test_config_sync_service.py index fe19ceb44..78bcb4cf8 100644 --- a/test/backend/services/test_config_sync_service.py +++ b/test/backend/services/test_config_sync_service.py @@ -1,6 +1,6 @@ import sys import types -import importlib.machinery +import importlib from unittest.mock import patch, MagicMock, call import pytest diff --git a/test/backend/services/test_config_sync_service_voice.py b/test/backend/services/test_config_sync_service_voice.py index c1e1a7bd9..213fbcdf3 100644 --- a/test/backend/services/test_config_sync_service_voice.py +++ b/test/backend/services/test_config_sync_service_voice.py @@ -5,6 +5,7 @@ import importlib import sys import types +import importlib from unittest.mock import patch, MagicMock import pytest diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py index 139ded1a2..559677527 100644 --- a/test/backend/services/test_model_health_service.py +++ b/test/backend/services/test_model_health_service.py @@ -24,6 +24,7 @@ def __getattr__(cls, key): sys.modules['utils'] = MockModule() sys.modules['utils.auth_utils'] = MockModule() sys.modules['utils.config_utils'] = MockModule() +sys.modules['utils.memory_utils'] = MockModule() sys.modules['utils.model_name_utils'] = MockModule() sys.modules['consts'] = MockModule() consts_const_module = MockModule() @@ -221,6 +222,51 @@ async def test_perform_connectivity_check_vlm(): mock_model_instance.check_connectivity.assert_called_once() +@pytest.mark.asyncio +async def test_perform_connectivity_check_dashscope_multimodal_uses_provider_catalog(): + model_provider_service = types.ModuleType("services.model_provider_service") + model_provider_service.get_provider_models = mock.AsyncMock(return_value=[ + {"id": "qwen-image-max", "model_type": "vlm2"}, + ]) + + with mock.patch.dict(sys.modules, {"services.model_provider_service": model_provider_service}), \ + mock.patch("backend.services.model_health_service.OpenAIVLModel") as mock_model: + result = await _perform_connectivity_check( + "qwen-image-max", + "vlm2", + "https://dashscope.aliyuncs.com/compatible-mode/v1/", + "test-key", + model_factory="dashscope", + ) + + assert result is True + model_provider_service.get_provider_models.assert_awaited_once_with({ + "provider": "dashscope", + "model_type": "vlm2", + "api_key": "test-key", + }) + mock_model.assert_not_called() + + +@pytest.mark.asyncio +async def test_perform_connectivity_check_tokenpony_multimodal_catalog_error_returns_false(): + model_provider_service = types.ModuleType("services.model_provider_service") + model_provider_service.get_provider_models = mock.AsyncMock(return_value=[ + {"_error": "authentication_failed", "_message": "Invalid API key"}, + ]) + + with mock.patch.dict(sys.modules, {"services.model_provider_service": model_provider_service}): + result = await _perform_connectivity_check( + "qwen-vl-plus", + "vlm3", + "https://api.tokenpony.cn/v1/", + "bad-key", + model_factory="tokenpony", + ) + + assert result is False + + @pytest.mark.asyncio async def test_perform_connectivity_check_stt(): # Setup diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 83a070fe0..be9c5e406 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -83,7 +83,22 @@ def get_value(status): return status or _ModelConnectStatusEnum.NOT_DETECTED.value +class _ToolValidateRequest: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class _ProcessParams: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def model_dump(self, *args, **kwargs): + return dict(self.__dict__) + + consts_model_mod.ModelConnectStatusEnum = _ModelConnectStatusEnum +consts_model_mod.ToolValidateRequest = _ToolValidateRequest +consts_model_mod.ProcessParams = _ProcessParams sys.modules["consts.model"] = consts_model_mod if "consts" not in sys.modules: sys.modules["consts"] = types.ModuleType("consts") @@ -93,6 +108,23 @@ def get_value(status): consts_const_mod.LOCALHOST_IP = "127.0.0.1" consts_const_mod.LOCALHOST_NAME = "localhost" consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal" +consts_const_mod.DATA_PROCESS_SERVICE = "http://data-process" +consts_const_mod.FILE_PREVIEW_SIZE_LIMIT = 100 * 1024 * 1024 +consts_const_mod.MAX_CONCURRENT_UPLOADS = 5 +consts_const_mod.OFFICE_MIME_TYPES = [] +consts_const_mod.UPLOAD_FOLDER = "uploads" +consts_const_mod.LOCAL_MCP_SERVER = "http://local-mcp" +consts_const_mod.MCP_MANAGEMENT_API = "http://mcp-management" +consts_const_mod.LIBREOFFICE_PROFILE_DIR = "libreoffice-profile" +consts_const_mod.DEFAULT_TENANT_ID = "tenant_id" +consts_const_mod.DEFAULT_USER_ID = "user_id" +consts_const_mod.IS_SPEED_MODE = False +consts_const_mod.SUPABASE_JWT_SECRET = "test-secret" +consts_const_mod.SUPABASE_URL = "http://supabase" +consts_const_mod.SUPABASE_KEY = "supabase-key" +consts_const_mod.SERVICE_ROLE_KEY = "service-role-key" +consts_const_mod.DEBUG_JWT_EXPIRE_SECONDS = 3600 +consts_const_mod.LANGUAGE = "zh" # Fields required by utils.memory_utils and services.vectordatabase_service consts_const_mod.MODEL_CONFIG_MAPPING = { "llm": "LLM_ID", "embedding": "EMBEDDING_ID"} @@ -193,9 +225,23 @@ def _sort_models_by_id(model_list): utils_name_mod.sort_models_by_id = _sort_models_by_id sys.modules["utils.model_name_utils"] = utils_name_mod +# Stub utils.file_management_utils so file_management_service can be imported +# by other tests in the same pytest process without pulling auth/database deps. +utils_file_mgmt_mod = types.ModuleType("utils.file_management_utils") + + +async def _save_upload_file(*args, **kwargs): + return None + + +utils_file_mgmt_mod.save_upload_file = _save_upload_file +sys.modules["utils.file_management_utils"] = utils_file_mgmt_mod + # Stub database.model_management_db to avoid importing heavy DB client database_mod = types.ModuleType("database") +database_mod.__path__ = [] db_mm_mod = types.ModuleType("database.model_management_db") +db_attachment_mod = types.ModuleType("database.attachment_db") def _noop(*args, **kwargs): @@ -245,6 +291,22 @@ def _get_model_by_model_id(model_id: int, tenant_id: str): db_mm_mod.update_model_record = _noop sys.modules["database"] = database_mod sys.modules["database.model_management_db"] = db_mm_mod +for _attachment_func in [ + "copy_file", + "delete_file", + "file_exists", + "get_content_type", + "get_file_range", + "get_file_size_from_minio", + "get_file_stream", + "get_file_stream_raw", + "get_file_url", + "list_files", + "upload_fileobj", +]: + setattr(db_attachment_mod, _attachment_func, _noop) +sys.modules["database.attachment_db"] = db_attachment_mod +setattr(database_mod, "attachment_db", db_attachment_mod) # Stub database.tenant_config_db required by utils.config_utils db_tenant_cfg_mod = types.ModuleType("database.tenant_config_db") @@ -286,10 +348,15 @@ def _update_config_by_tenant_config_id(*args, **kwargs): services_vdb_mod = types.ModuleType("services.vectordatabase_service") +class _ElasticSearchService: + pass + + def _get_vector_db_core(): return object() +services_vdb_mod.ElasticSearchService = _ElasticSearchService services_vdb_mod.get_vector_db_core = _get_vector_db_core sys.modules["services.vectordatabase_service"] = services_vdb_mod @@ -340,7 +407,7 @@ def import_svc(): async def test_create_model_for_tenant_success_llm(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None) as mock_get_by_display, \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]) as mock_get_by_display, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("huggingface", "llama")): @@ -367,7 +434,7 @@ async def test_create_model_for_tenant_open_router_disables_ssl(): """When base_url contains 'open/router' ssl_verify should be set to False and model_factory to 'modelengine'.""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("modelengine", "m")): @@ -394,7 +461,7 @@ async def test_create_model_for_tenant_open_router_disables_ssl(): async def test_create_model_for_tenant_conflict_raises(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value={"model_id": "exists"}): + with mock.patch.object(svc, "get_models_by_display_name", return_value=[{"model_id": "exists", "model_type": "llm"}]): user_id = "u1" tenant_id = "t1" model_data = { @@ -415,7 +482,7 @@ async def test_create_model_for_tenant_display_name_conflict_valueerror(): svc = import_svc() existing_model = {"model_id": 1, "display_name": "existing_name"} - with mock.patch.object(svc, "get_model_by_display_name", return_value=existing_model): + with mock.patch.object(svc, "get_models_by_display_name", return_value=[existing_model]): user_id = "u1" tenant_id = "t1" model_data = { @@ -432,11 +499,58 @@ async def test_create_model_for_tenant_display_name_conflict_valueerror(): assert "existing_name" in str(exc.value) +@pytest.mark.asyncio +async def test_create_model_for_tenant_allows_same_display_name_across_multimodal_slots(): + """Image understanding, image generation, and video understanding are separate slots.""" + svc = import_svc() + + existing_models = [ + {"model_id": 1, "display_name": "Qwen3.6-27B", "model_type": "vlm"}, + {"model_id": 2, "display_name": "Qwen3.6-27B", "model_type": "vlm3"}, + ] + + with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models), \ + mock.patch.object(svc, "create_model_record") as mock_create, \ + mock.patch.object(svc, "split_repo_name", return_value=("Qwen", "Qwen3.6-27B")): + + model_data = { + "model_name": "Qwen/Qwen3.6-27B", + "display_name": "Qwen3.6-27B", + "base_url": "https://api.example.com/v1", + "model_type": "vlm2", + } + + await svc.create_model_for_tenant("u1", "t1", model_data) + + mock_create.assert_called_once() + + +@pytest.mark.asyncio +async def test_create_model_for_tenant_blocks_duplicate_within_same_multimodal_slot(): + svc = import_svc() + + with mock.patch.object( + svc, + "get_models_by_display_name", + return_value=[{"model_id": 1, "display_name": "Qwen3.6-27B", "model_type": "vlm"}], + ): + model_data = { + "model_name": "Qwen/Qwen3.6-27B", + "display_name": "Qwen3.6-27B", + "base_url": "https://api.example.com/v1", + "model_type": "vlm", + } + + with pytest.raises(Exception) as exc: + await svc.create_model_for_tenant("u1", "t1", model_data) + assert "already in use" in str(exc.value) + + @pytest.mark.asyncio async def test_create_model_for_tenant_multi_embedding_creates_two_records(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): @@ -458,7 +572,7 @@ async def test_create_model_for_tenant_multi_embedding_creates_two_records(): async def test_create_model_for_tenant_embedding_sets_dimension(): svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=1536)) as mock_dim, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-ada-002")): @@ -484,7 +598,7 @@ async def test_create_model_for_tenant_embedding_sets_default_chunk_batch(): """chunk_batch defaults to 10 when not provided for embedding models.""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "text-embedding-3-small")): @@ -513,7 +627,7 @@ async def test_create_model_for_tenant_multi_embedding_sets_default_chunk_batch( """chunk_batch defaults to 10 when not provided for multi_embedding models (covers line 79).""" svc = import_svc() - with mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + with mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "embedding_dimension_check", new=mock.AsyncMock(return_value=512)) as mock_dim, \ mock.patch.object(svc, "create_model_record") as mock_create, \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "clip")): @@ -591,7 +705,7 @@ async def test_batch_create_models_for_tenant_dashscope_provider(): mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", return_value=("qwen", "qwen-turbo")), \ mock.patch.object(svc, "add_repo_to_name", return_value="qwen/qwen-turbo"), \ - mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): @@ -617,7 +731,7 @@ async def test_batch_create_models_for_tenant_tokenpony_provider(): mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", return_value=("gpt", "gpt-4o")), \ mock.patch.object(svc, "add_repo_to_name", return_value="gpt/gpt-4o"), \ - mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 2})), \ mock.patch.object(svc, "create_model_record", return_value=True): @@ -650,7 +764,7 @@ async def test_batch_create_models_for_tenant_other_provider(): mock.patch.object(svc, "delete_model_record"), \ mock.patch.object(svc, "split_repo_name", return_value=("openai", "gpt-4")), \ mock.patch.object(svc, "add_repo_to_name", return_value="openai/gpt-4"), \ - mock.patch.object(svc, "get_model_by_display_name", return_value=None), \ + mock.patch.object(svc, "get_models_by_display_name", return_value=[]), \ mock.patch.object(svc, "prepare_model_dict", new=mock.AsyncMock(return_value={"model_id": 1})), \ mock.patch.object(svc, "create_model_record", return_value=True): diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index 4c742162c..29d2f75f6 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -427,33 +427,29 @@ def validate(self): # Expose on parent package for patch resolution setattr(sys.modules['services'], service_name, service_module) -# Load actual backend modules so that patch targets resolve correctly -import importlib # noqa: E402 -backend_module = importlib.import_module('backend') -sys.modules['backend'] = backend_module -backend_database_module = importlib.import_module('backend.database') -sys.modules['backend.database'] = backend_database_module -backend_database_client_module = importlib.import_module( - 'backend.database.client') -sys.modules['backend.database.client'] = backend_database_client_module -backend_services_module = importlib.import_module( - 'backend.services.tool_configuration_service') -# Ensure services package can resolve tool_configuration_service for patching -sys.modules['services.tool_configuration_service'] = backend_services_module - -# Mock services modules +# Mock services modules before importing tool_configuration_service so absolute +# imports inside that module do not walk into real service dependency chains. sys.modules['services'] = _create_package_mock('services') services_modules = { - 'file_management_service': {'get_llm_model': MagicMock(), 'validate_urls_access': MagicMock()}, - 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_embedding_model_by_index_name': MagicMock(), - 'get_rerank_model': MagicMock(), 'get_vector_db_core': MagicMock(), - 'ElasticSearchService': MagicMock()}, - 'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()}, + 'file_management_service': { + 'get_llm_model': MagicMock(), + 'validate_urls_access': MagicMock(return_value=True), + }, + 'vectordatabase_service': { + 'get_embedding_model': MagicMock(), + 'get_embedding_model_by_index_name': MagicMock(), + 'get_rerank_model': MagicMock(), + 'get_vector_db_core': MagicMock(), + 'ElasticSearchService': MagicMock(), + }, + 'tenant_config_service': { + 'get_selected_knowledge_list': MagicMock(), + 'build_knowledge_name_mapping': MagicMock(), + }, 'image_service': { 'get_vlm_model': MagicMock(), 'get_video_understanding_model': MagicMock(), }, - 'redis_service': {'get_redis_service': MagicMock()}, } for service_name, attrs in services_modules.items(): service_module = types.ModuleType(f'services.{service_name}') @@ -463,47 +459,19 @@ def validate(self): # Expose on parent package for patch resolution setattr(sys.modules['services'], service_name, service_module) -# Also expose selected service stubs under backend.services.* so patch decorators -# don't import heavy real modules during collection. -try: - import backend.services as backend_services_pkg -except Exception: - backend_services_pkg = types.ModuleType("backend.services") - sys.modules["backend.services"] = backend_services_pkg -for service_name, service_module in [ - ("file_management_service", sys.modules["services.file_management_service"]), -]: - setattr(backend_services_pkg, service_name, service_module) - sys.modules[f"backend.services.{service_name}"] = service_module - -# Build a deterministic backend.services.file_management_service stub used by -# TestGetLlmModel so cross-file module monkeypatching does not affect imports. -backend_file_mgmt_module = types.ModuleType("backend.services.file_management_service") -backend_file_mgmt_module.MODEL_CONFIG_MAPPING = {"llm": "llm"} -backend_file_mgmt_module.tenant_config_manager = MagicMock() -backend_file_mgmt_module.get_model_name_from_config = MagicMock(return_value="gpt-4") -backend_file_mgmt_module.MessageObserver = MagicMock() -backend_file_mgmt_module.OpenAILongContextModel = MagicMock() -backend_file_mgmt_module.validate_urls_access = MagicMock() - -def _stub_get_llm_model(tenant_id): - cfg_key = backend_file_mgmt_module.MODEL_CONFIG_MAPPING["llm"] - model_config = backend_file_mgmt_module.tenant_config_manager.get_model_config( - key=cfg_key, tenant_id=tenant_id - ) - observer = backend_file_mgmt_module.MessageObserver() - return backend_file_mgmt_module.OpenAILongContextModel( - observer=observer, - model_id=backend_file_mgmt_module.get_model_name_from_config(model_config), - api_base=model_config.get("base_url"), - api_key=model_config.get("api_key"), - max_context_tokens=model_config.get("max_tokens"), - ssl_verify=model_config.get("ssl_verify", True), - ) - -backend_file_mgmt_module.get_llm_model = _stub_get_llm_model -sys.modules["backend.services.file_management_service"] = backend_file_mgmt_module -setattr(backend_services_pkg, "file_management_service", backend_file_mgmt_module) +# Load actual backend modules so that patch targets resolve correctly +import importlib # noqa: E402 +backend_module = importlib.import_module('backend') +sys.modules['backend'] = backend_module +backend_database_module = importlib.import_module('backend.database') +sys.modules['backend.database'] = backend_database_module +backend_database_client_module = importlib.import_module( + 'backend.database.client') +sys.modules['backend.database.client'] = backend_database_client_module +backend_services_module = importlib.import_module( + 'backend.services.tool_configuration_service') +# Ensure services package can resolve tool_configuration_service for patching +sys.modules['services.tool_configuration_service'] = backend_services_module # Patch storage factory and MinIO config validation to avoid errors during initialization # These patches must be started before any imports that use MinioClient @@ -3643,6 +3611,7 @@ def test_get_llm_model_success(self, mock_tenant_config, mock_get_model_name, mo api_key="test_api_key", max_context_tokens=4096, ssl_verify=True, + timeout_seconds=None, ) @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) @@ -3682,6 +3651,42 @@ def test_get_llm_model_with_missing_config_values(self, mock_tenant_config, mock call_kwargs = mock_openai_model.call_args[1] assert call_kwargs["api_key"] is None assert call_kwargs["max_context_tokens"] is None + assert call_kwargs["timeout_seconds"] is None + + @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) + @patch('backend.services.file_management_service.MessageObserver') + @patch('backend.services.file_management_service.OpenAILongContextModel') + @patch('backend.services.file_management_service.get_model_name_from_config') + @patch('backend.services.file_management_service.tenant_config_manager') + def test_get_llm_model_with_timeout_seconds(self, mock_tenant_config, mock_get_model_name, mock_openai_model, mock_message_observer): + """Test get_llm_model passes configured timeout_seconds.""" + from backend.services.file_management_service import get_llm_model + + mock_config = { + "base_url": "http://api.example.com", + "api_key": "test_api_key", + "max_tokens": 4096, + "timeout_seconds": 30, + } + mock_tenant_config.get_model_config.return_value = mock_config + mock_get_model_name.return_value = "gpt-4" + mock_observer_instance = Mock() + mock_message_observer.return_value = mock_observer_instance + mock_model_instance = Mock() + mock_openai_model.return_value = mock_model_instance + + result = get_llm_model("tenant123") + + assert result == mock_model_instance + mock_openai_model.assert_called_once_with( + observer=mock_observer_instance, + model_id="gpt-4", + api_base="http://api.example.com", + api_key="test_api_key", + max_context_tokens=4096, + ssl_verify=True, + timeout_seconds=30, + ) @patch('backend.services.file_management_service.MODEL_CONFIG_MAPPING', {"llm": "llm_config_key"}) @patch('backend.services.file_management_service.MessageObserver') diff --git a/test/sdk/core/tools/test_analyze_audio_video_tool.py b/test/sdk/core/tools/test_analyze_audio_video_tool.py index 94401b61d..7369ddfb2 100644 --- a/test/sdk/core/tools/test_analyze_audio_video_tool.py +++ b/test/sdk/core/tools/test_analyze_audio_video_tool.py @@ -44,9 +44,9 @@ def _fake_get_prompt(template_type, language=None, **_): storage_client=mock_storage_client, ) - result = tool._forward_impl([b"ID3audio-bytes"], "what happened?") + result = tool._forward_impl(audio_url=b"ID3audio-bytes", query="what happened?") - assert result == ["audio result"] + assert result == "audio result" assert calls == [("analyze_audio", "en")] mock_vlm_model.analyze_audio.assert_called_once() call_kwargs = mock_vlm_model.analyze_audio.call_args.kwargs @@ -55,6 +55,30 @@ def _fake_get_prompt(template_type, language=None, **_): observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing audio...") +def test_analyze_audio_schema_uses_single_url(): + assert "audio_url" in AnalyzeAudioTool.inputs + assert "audio_urls_list" not in AnalyzeAudioTool.inputs + assert AnalyzeAudioTool.output_type == "string" + + +def test_analyze_audio_accepts_legacy_url_list(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + monkeypatch.setattr( + analyze_audio_tool, + "get_prompt_template", + lambda template_type, language=None, **_: {"system_prompt": "Analyze audio for {{ query }}"}, + ) + mock_vlm_model.analyze_audio.return_value = SimpleNamespace(content="audio result") + tool = AnalyzeAudioTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl(audio_urls_list=[b"ID3audio-bytes"], query="what happened?") + + assert result == "audio result" + + def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_storage_client): vlm_model = SimpleNamespace( model_id="Qwen/Qwen3-VL-32B-Instruct", @@ -67,7 +91,7 @@ def test_analyze_audio_rejects_siliconflow_non_omni_model(observer_en, mock_stor ) with pytest.raises(ValueError) as exc_info: - tool._forward_impl([b"ID3audio-bytes"], "what happened?") + tool._forward_impl(audio_url=b"ID3audio-bytes", query="what happened?") assert "Please choose a Qwen3-Omni model" in str(exc_info.value) @@ -87,9 +111,9 @@ def _fake_get_prompt(template_type, language=None, **_): storage_client=mock_storage_client, ) - result = tool._forward_impl([b"\x00\x00\x00\x18ftypmp42video-bytes"], "what happened?") + result = tool._forward_impl(video_url=b"\x00\x00\x00\x18ftypmp42video-bytes", query="what happened?") - assert result == ["video result"] + assert result == "video result" assert calls == [("analyze_video", "en")] mock_vlm_model.analyze_video.assert_called_once() call_kwargs = mock_vlm_model.analyze_video.call_args.kwargs @@ -98,6 +122,30 @@ def _fake_get_prompt(template_type, language=None, **_): observer_en.add_message.assert_called_once_with("", ProcessType.TOOL, "Analyzing video...") +def test_analyze_video_schema_uses_single_url(): + assert "video_url" in AnalyzeVideoTool.inputs + assert "video_urls_list" not in AnalyzeVideoTool.inputs + assert AnalyzeVideoTool.output_type == "string" + + +def test_analyze_video_accepts_legacy_url_list(observer_en, mock_vlm_model, mock_storage_client, monkeypatch): + monkeypatch.setattr( + analyze_video_tool, + "get_prompt_template", + lambda template_type, language=None, **_: {"system_prompt": "Analyze video for {{ query }}"}, + ) + mock_vlm_model.analyze_video.return_value = SimpleNamespace(content="video result") + tool = AnalyzeVideoTool( + observer=observer_en, + vlm_model=mock_vlm_model, + storage_client=mock_storage_client, + ) + + result = tool._forward_impl(video_urls_list=[b"\x00\x00\x00\x18ftypmp42video-bytes"], query="what happened?") + + assert result == "video result" + + @pytest.mark.parametrize( "tool_class,input_name,error_text", [