diff --git a/.gitignore b/.gitignore index f526ef422..58fd44f05 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,8 @@ wheels/ /package /temp **/tmp/ +.env* +.claude-trace/ /apps/agentfabric/tmp/ MANIFEST @@ -148,6 +150,7 @@ apps/agentfabric/config/local_user/* *.pt .run/ .run/* +.claude* # ast template ast_index_file.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 00657312b..64a6fc0e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,23 @@ - - - +default_language_version: + python: python3.11 repos: - repo: https://github.com/pycqa/flake8.git - rev: 4.0.0 + rev: 7.3.0 hooks: - id: flake8 exclude: ^(thirdparty/|examples/|tests/|projects/agent_skills/|projects/fin_research/examples/|ms_agent/utils/prompts\.py) - repo: https://github.com/PyCQA/isort.git - rev: 4.3.21 + rev: 5.13.2 hooks: - id: isort exclude: ^(examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - repo: https://github.com/pre-commit/mirrors-yapf.git - rev: v0.30.0 + rev: v0.32.0 hooks: - id: yapf exclude: ^(thirdparty/|examples/|projects/fin_research/examples/|projects/agent_skills/|tests/) - repo: https://github.com/pre-commit/pre-commit-hooks.git - rev: v3.1.0 + rev: v5.0.0 hooks: - id: trailing-whitespace exclude: ^(thirdparty/|tests/|projects/fin_research/examples/|projects/agent_skills/) diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index 06f4d130a..ce1ffd1f2 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -102,6 +102,14 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # Local codebase / document search (sirchmunk), exposed as the `localsearch` tool + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # Optional: llm_api_key, llm_base_url, llm_model_name (else inherited from `llm`) ``` For the complete list of supported tools and custom tools, please refer to [here](./Tools.md) @@ -220,19 +228,19 @@ In addition to yaml configuration, MS-Agent also supports several additional com > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. -- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. +- knowledge_search_paths: Comma-separated local search paths. Merges into `tools.localsearch.paths` and registers the **`localsearch`** tool (sirchmunk) for on-demand use by the model—not automatic per-turn injection. LLM settings are inherited from the `llm` module unless you set `tools.localsearch.llm_*` fields. ### Quick Start for Knowledge Search -Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: +Use `--knowledge_search_paths` or define `tools.localsearch` in yaml so the model can call `localsearch` when needed: ```bash # Using default agent.yaml configuration, automatically reuses LLM settings -ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "/path/to/docs" # Specify configuration file ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" ``` LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. -If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. +For a dedicated sirchmunk LLM, set `tools.localsearch.llm_api_key`, `llm_base_url`, and `llm_model_name` in yaml. Legacy top-level `knowledge_search` with the same keys is still read for backward compatibility. diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index f3dd74203..5ebfb01e5 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -102,6 +102,14 @@ tools: url: https://mcp.api-inference.modelscope.net/xxx/sse exclude: - map_geo + # 本地代码库/文档搜索(sirchmunk),对应模型可调用的 `localsearch` 工具 + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + mode: FAST + # 可选:llm_api_key、llm_base_url、llm_model_name(不填则从 `llm` 继承) ``` 支持的完整工具列表,以及自定义工具请参考 [这里](./tools) @@ -218,13 +226,13 @@ handler: custom_handler } } ``` -- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 +- knowledge_search_paths: 知识搜索路径,逗号分隔。会合并到 `tools.localsearch.paths` 并注册 **`localsearch`** 工具(sirchmunk),由模型按需调用,不再在每轮自动注入上下文;除非配置 `tools.localsearch.llm_*`,否则 LLM 从 `llm` 模块复用 > agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 ### 知识搜索快速使用 -通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: +通过 `--knowledge_search_paths` 或在 yaml 中配置 `tools.localsearch`,启用本地知识搜索(模型按需调用 `localsearch`): ```bash # 使用默认 agent.yaml 配置,自动复用 LLM 设置 @@ -235,4 +243,4 @@ ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_sea ``` LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 -如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 +若 sirchmunk 需独立 LLM,可在 yaml 的 `tools.localsearch` 下设置 `llm_api_key`、`llm_base_url`、`llm_model_name`。仍支持旧版顶层 `knowledge_search` 相同字段,以便迁移。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example deleted file mode 100644 index cc11a8a3d..000000000 --- a/examples/knowledge_search/agent.yaml.example +++ /dev/null @@ -1,86 +0,0 @@ -# Sirchmunk Knowledge Search 配置示例 -# Sirchmunk Knowledge Search Configuration Example - -# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: - -llm: - service: modelscope - model: Qwen/Qwen3-235B-A22B-Instruct-2507 - modelscope_api_key: - modelscope_base_url: https://api-inference.modelscope.cn/v1 - -generation_config: - temperature: 0.3 - top_k: 20 - stream: true - -# Knowledge Search 配置(可选) -# 用于在本地代码库中搜索相关信息 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录,用于缓存 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则使用上面 llm 的配置) - llm_api_key: - llm_base_url: https://api.openai.com/v1 - llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:聚类相似度阈值 - cluster_sim_threshold: 0.85 - - # 可选:聚类 TopK - cluster_sim_top_k: 3 - - # 可选:是否重用之前的知识 - reuse_knowledge: true - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:最大循环次数 - max_loops: 10 - - # 可选:最大 token 预算 - max_token_budget: 128000 - -prompt: - system: | - You are an assistant that helps me complete tasks. - -max_chat_round: 9999 - -# 使用说明: -# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 -# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 -# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM -# 4. searching_detail 包含搜索日志和元数据,可用于前端展示 -# -# Python 使用示例: -# ```python -# from ms_agent import LLMAgent -# from ms_agent.config import Config -# -# config = Config.from_task('path/to/agent.yaml') -# agent = LLMAgent(config=config) -# result = await agent.run('如何实现用户认证功能?') -# -# # 获取搜索详情(用于前端展示) -# for msg in result: -# if msg.role == 'user': -# print(f"Search logs: {msg.searching_detail}") -# print(f"Search results: {msg.search_result}") -# ``` -# -# CLI 测试命令: -# export LLM_API_KEY="your-api-key" -# export LLM_BASE_URL="https://api.openai.com/v1" -# export LLM_MODEL_NAME="gpt-4o-mini" -# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms-agent-skills/scripts/check_ms_agent.py b/ms-agent-skills/scripts/check_ms_agent.py index 64e4668fa..59f56c3f7 100644 --- a/ms-agent-skills/scripts/check_ms_agent.py +++ b/ms-agent-skills/scripts/check_ms_agent.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import subprocess import sys -import json - def check_import() -> dict: """Check that ms_agent is importable.""" diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index 9e17c8c44..ba568d04c 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -1,12 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Tuple, Union from ms_agent.llm import Message from ms_agent.utils import read_history, save_history from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_RETRY_COUNT -from omegaconf import DictConfig class Agent(ABC): @@ -74,6 +74,15 @@ def save_history(self, messages: Any, **kwargs): return save_history(self.output_dir, self.tag, self.config, messages) + def list_snapshots(self) -> list: + """Return snapshots for this agent's output_dir, most recent first.""" + from ms_agent.utils.snapshot import list_snapshots + return list_snapshots(self.output_dir) + + def rollback(self, commit_hash: str) -> bool: + """Restore output_dir to a previous snapshot and truncate history.""" + raise NotImplementedError() + def next_flow(self, idx: int) -> int: """Used in workflow, decide which agent goes next.""" return idx + 1 diff --git a/ms_agent/agent/code_agent.py b/ms_agent/agent/code_agent.py index 33b2b63a4..63d378867 100644 --- a/ms_agent/agent/code_agent.py +++ b/ms_agent/agent/code_agent.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import Any, List, Union from ms_agent.llm import Message -from omegaconf import DictConfig - from .base import Agent diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 5f2ddf2e7..cdca437e0 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,18 +2,18 @@ import asyncio import importlib import inspect +import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy +from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping @@ -24,13 +24,15 @@ from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger -from omegaconf import DictConfig, OmegaConf - +from ms_agent.utils.snapshot import take_snapshot +from ms_agent.utils.task_manager import TaskManager from ..config.config import Config, ConfigLifecycleHandler from .base import Agent logger = get_logger() +_MISSING_ENABLE_SNAPSHOTS = object() + class LLMAgent(Agent): """ @@ -84,6 +86,37 @@ class LLMAgent(Agent): DEFAULT_MAX_CHAT_ROUND = 20 + @staticmethod + def _coerce_enable_snapshots_value(value: Any) -> bool: + if isinstance(value, str): + return value.strip().lower() in ('1', 'true', 'yes', 'on') + return bool(value) + + @staticmethod + def resolve_enable_snapshots(config: Any) -> bool: + """Resolve whether to take automatic pre-task snapshots. + + Tool-spawned sub-agents (``ms_agent_subagent`` in config) default to + ``False``; all other agents default to ``True``. An explicit + ``enable_snapshots`` in config always wins (including string forms + like ``\"false\"`` coerced to boolean). + """ + if OmegaConf.is_config(config): + raw = OmegaConf.select( + config, 'enable_snapshots', default=_MISSING_ENABLE_SNAPSHOTS) + if raw is not _MISSING_ENABLE_SNAPSHOTS and raw is not None: + return LLMAgent._coerce_enable_snapshots_value(raw) + sub = bool( + OmegaConf.select(config, 'ms_agent_subagent', default=False)) + return not sub + if isinstance(config, dict): + if 'enable_snapshots' in config and config[ + 'enable_snapshots'] is not None: + return LLMAgent._coerce_enable_snapshots_value( + config['enable_snapshots']) + return not bool(config.get('ms_agent_subagent')) + return True + TOTAL_PROMPT_TOKENS = 0 TOTAL_COMPLETION_TOKENS = 0 TOTAL_CACHED_TOKENS = 0 @@ -105,9 +138,9 @@ def __init__( super().__init__(config, tag, trust_remote_code) self.callbacks: List[Callback] = [] self.tool_manager: Optional[ToolManager] = None + self.task_manager: Optional[TaskManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None - self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None self.max_chat_round: int = 0 @@ -162,7 +195,8 @@ def _ensure_auto_skills(self) -> bool: # Check sandbox requirements use_sandbox = getattr(skills_config, 'use_sandbox', True) if use_sandbox: - from ms_agent.utils.docker_utils import is_docker_daemon_running + from ms_agent.utils.docker_utils import \ + is_docker_daemon_running if not is_docker_daemon_running(): logger.warning( @@ -351,6 +385,24 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: return messages + def rollback(self, commit_hash: str) -> bool: + """Restore output_dir to snapshot and truncate message history.""" + from ms_agent.utils.snapshot import restore_snapshot + ok, message_count = restore_snapshot(self.output_dir, commit_hash) + if not ok: + return False + # Truncate saved history to the message count at snapshot time + _, saved_messages = read_history(self.output_dir, self.tag) + if saved_messages and message_count < len(saved_messages): + save_history(self.output_dir, self.tag, self.config, + saved_messages[:message_count]) + # Clear read cache on FileSystemTool so stale entries don't block edits + if self.tool_manager is not None: + for tool in self.tool_manager.extra_tools: + if hasattr(tool, '_read_cache'): + tool._read_cache.clear() + return True + def register_callback(self, callback: Callback): """ Register a new callback to be triggered during the agent's lifecycle. @@ -477,6 +529,18 @@ def register_callback_from_config(self): async def on_task_begin(self, messages: List[Message]): self.log_output(f'Agent {self.tag} task beginning.') + if self.resolve_enable_snapshots(self.config): + _user_content = next( + ((getattr(m, 'content', '') or '')[:80] + for m in messages if getattr(m, 'role', '') == 'user'), + '', + ) + take_snapshot( + self.output_dir, + f'[pre] {_user_content}' + if _user_content else '[pre] new task', + message_count=len(messages), + ) await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): @@ -528,6 +592,7 @@ async def parallel_tool_call(self, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], resources=tool_call_result_format.resources, + tool_detail=tool_call_result_format.tool_detail, ) if _new_message.tool_call_id is None: @@ -540,6 +605,7 @@ async def parallel_tool_call(self, async def prepare_tools(self): """Initialize and connect the tool manager.""" + self.task_manager = TaskManager() self.tool_manager = ToolManager( self.config, self.mcp_config, @@ -547,10 +613,16 @@ async def prepare_tools(self): trust_remote_code=self.trust_remote_code, ) await self.tool_manager.connect() + for tool in self.tool_manager.extra_tools: + if hasattr(tool, 'set_task_manager'): + tool.set_task_manager(self.task_manager) async def cleanup_tools(self): """Cleanup resources used by the tool manager.""" - await self.tool_manager.cleanup() + if self.task_manager is not None: + self.task_manager.kill_all() + if self.tool_manager is not None: + await self.tool_manager.cleanup() @property def stream(self): @@ -582,16 +654,41 @@ def reasoning_output(self) -> str: DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) - def _write_reasoning(self, text: str): + _THINKING_SEP = '─' * 40 + + def _reasoning_stream(self): + if self.reasoning_output.lower() == 'stdout': + return sys.stdout + return sys.stderr + + def _write_reasoning(self, text: str, dim: bool = False): if not text: return - if self.reasoning_output.lower() == 'stdout': - sys.stdout.write(text) - sys.stdout.flush() + stream = self._reasoning_stream() + use_ansi = hasattr(stream, 'isatty') and stream.isatty() + if dim and use_ansi: + text = f'\033[2m{text}\033[0m' + stream.write(text) + stream.flush() + + def _write_thinking_header(self): + stream = self._reasoning_stream() + use_ansi = hasattr(stream, 'isatty') and stream.isatty() + line = f'{self._THINKING_SEP[:15]} thinking {self._THINKING_SEP[25:]}' + if use_ansi: + stream.write(f'\033[2m{line}\033[0m\n') + else: + stream.write(line + '\n') + stream.flush() + + def _write_thinking_footer(self): + stream = self._reasoning_stream() + use_ansi = hasattr(stream, 'isatty') and stream.isatty() + if use_ansi: + stream.write(f'\n\033[2m{self._THINKING_SEP}\033[0m\n') else: - # default: stderr - sys.stderr.write(text) - sys.stderr.flush() + stream.write(f'\n{self._THINKING_SEP}\n') + stream.flush() @property def system(self): @@ -636,11 +733,7 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): - """Process RAG or knowledge search to enrich the user query with context. - - This method handles both traditional RAG and sirchmunk-based knowledge search. - For knowledge search, it also populates searching_detail and search_result - fields in the message for frontend display and next-turn LLM context. + """Process RAG to enrich the user query with context. Args: messages (List[Message]): The message list to process. @@ -654,23 +747,6 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search - if self.knowledge_search is not None: - # Perform search and get results - search_result = await self.knowledge_search.query(query) - search_details = self.knowledge_search.get_search_details() - - # Store search details in the message for frontend display - user_message.searching_detail = search_details - user_message.search_result = search_result - - # Build enriched context from search results - if search_result: - # Append search context to user query - context = search_result - user_message.content = ( - f'Relevant context retrieved from codebase search:\n\n{context}\n\n' - f'User question: {query}') async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -757,18 +833,6 @@ async def prepare_rag(self): f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) - async def prepare_knowledge_search(self): - """Load and initialize the knowledge search component from the config.""" - if self.knowledge_search is not None: - # Already initialized (e.g. by caller before run_loop), skip to avoid - # overwriting a configured instance (e.g. one with streaming callbacks set). - return - if hasattr(self.config, 'knowledge_search'): - ks_config = self.config.knowledge_search - if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch( - self.config) - async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. @@ -836,6 +900,18 @@ def handle_new_response(self, messages: List[Message], and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' + def _append_task_notifications(self, + messages: List[Message]) -> List[Message]: + """Inject drained TaskManager completion notices as a user message.""" + if self.task_manager is None: + return messages + notes = self.task_manager.drain_notifications() + if not notes: + return messages + body = '[Background task updates]\n' + '\n'.join(notes) + messages.append(Message(role='user', content=body)) + return messages + @async_retry(max_attempts=Agent.retry_count, delay=1.0) async def step( self, messages: List[Message] @@ -863,6 +939,7 @@ async def step( List[Message]: Updated message history after this step. """ messages = deepcopy(messages) + messages = self._append_task_notifications(messages) if (not self.load_cache) or messages[-1].role != 'assistant': messages = await self.condense_memory(messages) await self.on_generate_response(messages) @@ -875,13 +952,13 @@ async def step( is_first = True _response_message = None _printed_reasoning_header = False + _printed_reasoning_footer = False for _response_message in self.llm.generate( messages, tools=tools): if is_first: messages.append(_response_message) is_first = False - # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: reasoning_text = ( getattr(_response_message, 'reasoning_content', '') @@ -892,19 +969,33 @@ async def step( new_reasoning = reasoning_text[len(_reasoning):] if new_reasoning: if not _printed_reasoning_header: - self._write_reasoning('[thinking]:\n') + self._write_thinking_header() _printed_reasoning_header = True - self._write_reasoning(new_reasoning) + self._write_reasoning(new_reasoning, dim=True) _reasoning = reasoning_text new_content = _response_message.content[len(_content):] - sys.stdout.write(new_content) - sys.stdout.flush() + if new_content: + if _printed_reasoning_header and not _printed_reasoning_footer: + self._write_thinking_footer() + _printed_reasoning_footer = True + sys.stdout.write(new_content) + sys.stdout.flush() _content = _response_message.content messages[-1] = _response_message yield messages - if self.show_reasoning and _printed_reasoning_header: - self._write_reasoning('\n') + if _printed_reasoning_header and not _printed_reasoning_footer: + self._write_thinking_footer() + + # Handle reasoning summaries that arrive after content + if self.show_reasoning and _response_message is not None: + final_reasoning = getattr(_response_message, + 'reasoning_content', '') or '' + if final_reasoning and not _printed_reasoning_header: + self._write_thinking_header() + self._write_reasoning(final_reasoning, dim=True) + self._write_thinking_footer() + sys.stdout.write('\n') else: _response_message = self.llm.generate(messages, tools=tools) @@ -913,9 +1004,9 @@ async def step( getattr(_response_message, 'reasoning_content', '') or '') if reasoning_text: - self._write_reasoning('[thinking]:\n') - self._write_reasoning(reasoning_text) - self._write_reasoning('\n') + self._write_thinking_header() + self._write_reasoning(reasoning_text, dim=True) + self._write_thinking_footer() if _response_message.content: self.log_output('[assistant]:') self.log_output(_response_message.content) @@ -1111,9 +1202,13 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() - await self.prepare_knowledge_search() self.runtime.tag = self.tag + self.task_manager = TaskManager() + for tool in self.tool_manager.extra_tools: + if hasattr(tool, 'set_task_manager'): + tool.set_task_manager(self.task_manager) + if messages is None: messages = self.query @@ -1142,6 +1237,12 @@ async def run_loop(self, messages: Union[List[Message], str], self.log_output('[' + message.role + ']:') self.log_output(message.content) while not self.runtime.should_stop: + if self.task_manager is not None: + notifications = self.task_manager.drain_notifications() + if notifications: + messages.append( + Message( + role='user', content='\n'.join(notifications))) async for messages in self.step(messages): yield messages self.runtime.round += 1 diff --git a/ms_agent/agent/loader.py b/ms_agent/agent/loader.py index 21b1687a7..48ed74d91 100644 --- a/ms_agent/agent/loader.py +++ b/ms_agent/agent/loader.py @@ -3,12 +3,11 @@ import inspect import os import sys +from omegaconf import DictConfig, OmegaConf from typing import Dict, Optional from ms_agent.config.config import Config from ms_agent.utils.constants import DEFAULT_AGENT_FILE, DEFAULT_TAG -from omegaconf import DictConfig, OmegaConf - from .base import Agent @@ -44,8 +43,8 @@ def build(cls, None) is None and config_dir_or_id is not None: agent_config.local_dir = config_dir_or_id - from .llm_agent import LLMAgent from .code_agent import CodeAgent + from .llm_agent import LLMAgent agent_type = LLMAgent.AGENT_NAME if 'code_file' in kwargs: code_file = kwargs.pop('code_file') diff --git a/ms_agent/callbacks/base.py b/ms_agent/callbacks/base.py index 849fe7069..e8b5f0740 100644 --- a/ms_agent/callbacks/base.py +++ b/ms_agent/callbacks/base.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.agent.runtime import Runtime from ms_agent.llm.utils import Message -from omegaconf import DictConfig class Callback: diff --git a/ms_agent/callbacks/input_callback.py b/ms_agent/callbacks/input_callback.py index e44db1e31..4017681fd 100644 --- a/ms_agent/callbacks/input_callback.py +++ b/ms_agent/callbacks/input_callback.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/capabilities/__init__.py b/ms_agent/capabilities/__init__.py index 73feb864e..c5717be55 100644 --- a/ms_agent/capabilities/__init__.py +++ b/ms_agent/capabilities/__init__.py @@ -23,6 +23,7 @@ """ from __future__ import annotations + from typing import Any from ms_agent.capabilities.descriptor import CapabilityDescriptor @@ -39,13 +40,9 @@ def create_registry(config: Any = None) -> CapabilityRegistry: """ registry = CapabilityRegistry() - from ms_agent.capabilities.wrappers import ( - agent_delegate, - deep_research, - filesystem, - lsp_code_server, - web_search, - ) + from ms_agent.capabilities.wrappers import (agent_delegate, deep_research, + filesystem, lsp_code_server, + web_search) filesystem.register_all(registry, config) lsp_code_server.register_all(registry, config) diff --git a/ms_agent/capabilities/mcp_server.py b/ms_agent/capabilities/mcp_server.py index 650dc33cf..84e201f1e 100644 --- a/ms_agent/capabilities/mcp_server.py +++ b/ms_agent/capabilities/mcp_server.py @@ -1,11 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import argparse +import json import logging import os import sys - -import json from dotenv import find_dotenv, load_dotenv + from ms_agent.capabilities import create_registry logger = logging.getLogger(__name__) diff --git a/ms_agent/cli/app.py b/ms_agent/cli/app.py index d9a690d7d..e9d94d281 100644 --- a/ms_agent/cli/app.py +++ b/ms_agent/cli/app.py @@ -55,13 +55,15 @@ def define_args(parsers: argparse.ArgumentParser): def execute(self): if self.args.app_type == 'doc_research': - from ms_agent.app.doc_research import launch_server as launch_doc_research + from ms_agent.app.doc_research import \ + launch_server as launch_doc_research launch_doc_research( server_name=self.args.server_name, server_port=self.args.server_port, share=self.args.share) elif self.args.app_type == 'fin_research': - from ms_agent.app.fin_research import launch_server as launch_fin_research + from ms_agent.app.fin_research import \ + launch_server as launch_fin_research launch_fin_research( server_name=self.args.server_name, server_port=self.args.server_port, diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index 67ec8d563..1bb1d67ad 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -3,13 +3,12 @@ import asyncio import os from importlib import resources as importlib_resources +from omegaconf import OmegaConf from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII -from omegaconf import OmegaConf - from .base import CLICommand logger = get_logger() @@ -48,6 +47,22 @@ class RunCMD(CLICommand): def __init__(self, args): self.args = args + def load_env_file(self): + """Load environment variables from .env file in current directory.""" + env_file = os.path.join(os.getcwd(), '.env') + if os.path.exists(env_file): + with open(env_file, 'r') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#') and '=' in line: + key, value = line.split('=', 1) + key = key.strip() + value = value.strip() + # Only set if not already set in environment + if key not in os.environ: + os.environ[key] = value + logger.debug(f'Loaded {key} from .env file') + @staticmethod def define_args(parsers: argparse.ArgumentParser): """Define args for run command.""" @@ -136,9 +151,7 @@ def define_args(parsers: argparse.ArgumentParser): required=False, type=str, default=None, - help= - 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' - ) + help='Comma-separated list of paths for knowledge search.') parser.set_defaults(func=subparser_func) def execute(self): @@ -170,7 +183,6 @@ def execute(self): def _execute_with_config(self): Env.load_dotenv_into_environ(getattr(self.args, 'env', None)) - if not self.args.config: current_dir = os.getcwd() if os.path.exists(os.path.join(current_dir, AGENT_CONFIG_FILE)): @@ -218,31 +230,28 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) - # If knowledge_search_paths is provided, configure SirchmunkSearch + # If knowledge_search_paths is provided, configure tools.localsearch if getattr(self.args, 'knowledge_search_paths', None): paths = [ p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip() ] if paths: - if 'knowledge_search' not in config or not config.knowledge_search: - # No existing knowledge_search config, create minimal config - # LLM settings will be auto-reused from llm module by SirchmunkSearch - knowledge_search_config = { - 'name': 'SirchmunkSearch', + if not hasattr(config, 'tools') or config.tools is None: + config['tools'] = OmegaConf.create({}) + tl = getattr(config.tools, 'localsearch', None) + if tl is None or not OmegaConf.is_config(tl): + localsearch_config = { 'paths': paths, 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create( - knowledge_search_config) + config.tools['localsearch'] = OmegaConf.create( + localsearch_config) else: - # Existing knowledge_search config found, only update paths - # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container( - config.knowledge_search, resolve=True) + existing = OmegaConf.to_container(tl, resolve=True) existing['paths'] = paths - config['knowledge_search'] = OmegaConf.create(existing) + config.tools['localsearch'] = OmegaConf.create(existing) if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader diff --git a/ms_agent/config/config.py b/ms_agent/config/config.py index 2f6175524..c1f1ab85c 100644 --- a/ms_agent/config/config.py +++ b/ms_agent/config/config.py @@ -3,14 +3,13 @@ import os.path from abc import abstractmethod from copy import deepcopy -from typing import Any, Dict, Union - -from ms_agent.prompting import apply_prompt_files -from ms_agent.utils import get_logger from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.basecontainer import BaseContainer +from typing import Any, Dict, Union from modelscope import snapshot_download +from ms_agent.prompting import apply_prompt_files +from ms_agent.utils import get_logger from ..utils.constants import TOOL_PLUGIN_NAME from .env import Env diff --git a/ms_agent/config/env.py b/ms_agent/config/env.py index 83553254c..39a03e9ef 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from copy import copy -from typing import Dict, Optional - from dotenv import load_dotenv +from typing import Dict, Optional class Env: diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py index 33362beee..8746f27c9 100644 --- a/ms_agent/knowledge_search/__init__.py +++ b/ms_agent/knowledge_search/__init__.py @@ -1,11 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Knowledge search module based on sirchmunk. +"""Backward-compatible re-exports for sirchmunk local search. -This module provides integration between sirchmunk's AgenticSearch -and the ms_agent framework, enabling intelligent codebase search -capabilities similar to RAG. +Implementation lives in :mod:`ms_agent.tools.search.sirchmunk_search`; prefer +importing ``SirchmunkSearch`` from there in new code. """ -from .sirchmunk_search import SirchmunkSearch +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch __all__ = ['SirchmunkSearch'] diff --git a/ms_agent/llm/anthropic_llm.py b/ms_agent/llm/anthropic_llm.py index 6f93a1167..e229b5140 100644 --- a/ms_agent/llm/anthropic_llm.py +++ b/ms_agent/llm/anthropic_llm.py @@ -1,11 +1,118 @@ +import httpx import inspect +import json +from omegaconf import DictConfig, OmegaConf from typing import Any, Dict, Generator, Iterator, List, Optional, Union from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool, ToolCall from ms_agent.utils import assert_package_exist, retry from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig, OmegaConf + + +class _SSEEventInjector(httpx.SyncByteStream): + """Injects SSE ``event:`` lines into DashScope's streaming response. + + DashScope only emits ``data:`` lines in its SSE stream. The Anthropic + SDK's ``MessageStream`` relies on ``event:`` lines to route events. + This wrapper extracts the ``type`` from the JSON payload and prepends + the matching ``event:`` line so the SDK can process events correctly. + """ + + def __init__(self, stream): + self._stream = stream + self._buffer = b'' + + def __iter__(self): + for chunk in self._stream: + self._buffer += chunk + while b'\n\n' in self._buffer: + block, self._buffer = self._buffer.split(b'\n\n', 1) + if block.strip(): + yield self._inject(block) + b'\n\n' + if self._buffer.strip(): + yield self._inject(self._buffer) + b'\n\n' + + @staticmethod + def _inject(block: bytes) -> bytes: + for line in block.split(b'\n'): + s = line.strip() + if s.startswith(b'data:'): + try: + t = json.loads(s[5:].strip()).get('type', '') + if t: + return b'event: ' + t.encode() + b'\n' + block + except (json.JSONDecodeError, ValueError): + pass + return block + + def close(self): + if hasattr(self._stream, 'close'): + self._stream.close() + + +class DashScopeAnthropicTransport(httpx.BaseTransport): + """Routes Anthropic SDK requests to DashScope's compatible-mode endpoint. + + DashScope returns Anthropic-format SSE responses for vertex AI Claude models + (e.g. vertex_ai.claude-opus-4-6), but expects requests at + /compatible-mode/v1/chat/completions with a native protocol flag rather than + the standard Anthropic /v1/messages path. This transport transparently + rewrites URL, auth headers, and body so the Anthropic SDK works unmodified. + """ + + def __init__(self, + dashscope_url: str, + api_key: str, + supplier: Optional[str] = None): + self.dashscope_url = dashscope_url + self.api_key = api_key + self.supplier = supplier + self._transport = httpx.HTTPTransport() + + def handle_request(self, request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + is_streaming = bool(body.get('stream')) + + ext = body.setdefault('dashscope_extend_params', {}) + ext['using_native_protocol'] = True + if self.supplier and 'supplier' not in ext: + ext['supplier'] = self.supplier + + new_headers = { + 'content-type': 'application/json', + 'authorization': f'Bearer {self.api_key}', + } + _skip = frozenset({ + 'x-api-key', 'content-type', 'authorization', 'content-length', + 'host', 'transfer-encoding' + }) + for key, value in request.headers.items(): + k = key.lower() + if k not in _skip and not k.startswith('anthropic'): + new_headers[key] = value + + new_content = json.dumps(body).encode('utf-8') + new_request = httpx.Request( + method=request.method, + url=self.dashscope_url, + headers=new_headers, + content=new_content, + extensions=request.extensions, + ) + response = self._transport.handle_request(new_request) + + if is_streaming: + return httpx.Response( + status_code=response.status_code, + headers=response.headers, + stream=_SSEEventInjector(response.stream), + extensions=response.extensions, + ) + return response + + def close(self): + self._transport.close() class Anthropic(LLM): @@ -29,10 +136,31 @@ def __init__( if not api_key: raise ValueError('Anthropic API key is required.') - self.client = anthropic.Anthropic( - api_key=api_key, - base_url=base_url, - ) + self._is_dashscope = bool(base_url and 'dashscope' in base_url.lower()) + + if self._is_dashscope: + dashscope_url = base_url + if not dashscope_url.rstrip('/').endswith('/chat/completions'): + dashscope_url = dashscope_url.rstrip('/') + '/chat/completions' + supplier = config.llm.get('dashscope_supplier', None) + transport = DashScopeAnthropicTransport( + dashscope_url=dashscope_url, + api_key=api_key, + supplier=supplier, + ) + http_client = httpx.Client( + transport=transport, + timeout=httpx.Timeout(300.0, connect=60.0), + ) + self.client = anthropic.Anthropic( + api_key=api_key, + http_client=http_client, + ) + else: + self.client = anthropic.Anthropic( + api_key=api_key, + base_url=base_url, + ) self.args: Dict = OmegaConf.to_container( getattr(config, 'generation_config', DictConfig({}))) @@ -112,24 +240,42 @@ def _call_llm(self, formatted_messages = formatted_messages[1:] max_tokens = kwargs.pop('max_tokens', 16000) - extra_body = kwargs.get('extra_body', {}) - enable_thinking = extra_body.get('enable_thinking', False) - thinking_budget = extra_body.get('thinking_budget', max_tokens) + + enable_thinking = bool(kwargs.pop('enable_thinking', False)) + thinking_budget = kwargs.pop('thinking_budget', None) + thinking_type = kwargs.pop('thinking_type', None) + + raw_extra_body = kwargs.pop('extra_body', {}) or {} + extra_body = dict(raw_extra_body) if isinstance(raw_extra_body, + dict) else {} + enable_thinking = bool( + extra_body.pop('enable_thinking', enable_thinking)) + thinking_budget = extra_body.pop('thinking_budget', + thinking_budget) or max_tokens + thinking_type = extra_body.pop('thinking_type', thinking_type) + for _k in ('show_reasoning', 'reasoning_output'): + extra_body.pop(_k, None) params = { 'model': self.model, 'messages': formatted_messages, - 'max_tokens': max_tokens, - 'thinking': { - 'type': 'enabled' if enable_thinking else 'disabled', - 'budget_tokens': thinking_budget - } + 'max_tokens': max_tokens } + if thinking_type == 'adaptive': + params['thinking'] = {'type': 'adaptive'} + elif enable_thinking: + params['thinking'] = { + 'type': 'enabled', + 'budget_tokens': thinking_budget, + } + if system: params['system'] = system if tools: params['tools'] = tools + if extra_body: + kwargs['extra_body'] = extra_body params.update(kwargs) if stream: diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index b4a6ddaa8..00239286e 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig class DashScope(OpenAI): diff --git a/ms_agent/llm/deepseek_llm.py b/ms_agent/llm/deepseek_llm.py index e565308bc..3bc6a59d0 100644 --- a/ms_agent/llm/deepseek_llm.py +++ b/ms_agent/llm/deepseek_llm.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import List from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message, Tool -from omegaconf import DictConfig class DeepSeek(OpenAI): diff --git a/ms_agent/llm/llm.py b/ms_agent/llm/llm.py index 72af53467..803434f87 100644 --- a/ms_agent/llm/llm.py +++ b/ms_agent/llm/llm.py @@ -1,11 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import abstractmethod +from omegaconf import DictConfig from typing import Any, Dict, List, Optional from ms_agent.config import Config -from omegaconf import DictConfig - from ..utils.constants import DEFAULT_RETRY_COUNT from .utils import Message, Tool @@ -69,7 +68,7 @@ def from_config(cls, config: DictConfig) -> Any: Returns: The LLM instance. """ - from .model_mapping import all_services_mapping, OpenAI + from .model_mapping import OpenAI, all_services_mapping if config.llm.get('service') in all_services_mapping: return all_services_mapping[config.llm.service](config) else: diff --git a/ms_agent/llm/modelscope_llm.py b/ms_agent/llm/modelscope_llm.py index 7b761c5c0..bd1b54a56 100644 --- a/ms_agent/llm/modelscope_llm.py +++ b/ms_agent/llm/modelscope_llm.py @@ -1,7 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig + from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig class ModelScope(OpenAI): diff --git a/ms_agent/llm/openai.py b/ms_agent/llm/openai.py index 390f1d4ab..cd902119f 100644 --- a/ms_agent/llm/openai.py +++ b/ms_agent/llm/openai.py @@ -1,11 +1,11 @@ # flake8: noqa +import json import uuid +from openai import OpenAI, Stream +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from typing import TYPE_CHECKING, Any, Dict, List, Literal -import json from ms_agent.utils.logger import get_logger -from openai import OpenAI, Stream -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall logger = get_logger() diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py index dadc1bf1c..86969210f 100644 --- a/ms_agent/llm/openai_llm.py +++ b/ms_agent/llm/openai_llm.py @@ -1,6 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import httpx import inspect +import json from copy import deepcopy +from omegaconf import DictConfig, OmegaConf +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, Function) from typing import Any, Dict, Generator, Iterable, List, Optional from ms_agent.llm import LLM @@ -8,19 +13,39 @@ from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist, get_logger, retry) from ms_agent.utils.constants import get_service_config -from omegaconf import DictConfig, OmegaConf -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, Function) logger = get_logger() +class _DashScopeResponsesTransport(httpx.HTTPTransport): + """Rewrite /v1/responses -> /v1/chat/completions for DashScope proxy. + + DashScope serves the OpenAI Responses protocol on the chat/completions + path rather than the standard /v1/responses path. This transport + transparently rewrites the SDK's outgoing request so that + ``client.responses.create()`` hits the correct DashScope endpoint. + """ + + def handle_request(self, request): + if b'/v1/responses' in request.url.raw_path: + new_path = request.url.raw_path.replace(b'/v1/responses', + b'/v1/chat/completions') + request.url = request.url.copy_with(raw_path=new_path) + return super().handle_request(request) + + class OpenAI(LLM): """Base Class for OpenAI SDK LLMs. This class provides the base implementation for interacting with OpenAI-compatible models, including support for chat completions, streaming responses, and continue generates. + Supports the OpenAI Responses API (``client.responses.create``) when + ``generation_config.use_responses_api`` is ``true``. In this mode, + reasoning summaries are extracted and surfaced through + ``Message.reasoning_content`` so the agent's existing thinking display + works without change. + Args: config (`DictConfig`): The configuration object containing model and generation settings. base_url (`Optional[str]`): Custom base URL for the API endpoint. Defaults to None. @@ -58,6 +83,31 @@ def __init__( self.args: Dict = OmegaConf.to_container( getattr(config, 'generation_config', DictConfig({}))) + # Responses API support + self._use_responses_api = bool( + self.args.get('use_responses_api', False)) + self._responses_client = None + self._responses_state_mode = str( + self.args.get('responses_state_mode', 'stateless')).lower() + if self._responses_state_mode == 'stateful': + self._responses_state_mode = 'previous_response_id' + + if self._use_responses_api: + self._is_dashscope = bool(base_url + and 'dashscope' in base_url.lower()) + if self._is_dashscope: + http_client = httpx.Client( + transport=_DashScopeResponsesTransport(), + timeout=httpx.Timeout(300.0, connect=60.0), + ) + self._responses_client = openai.OpenAI( + api_key=api_key, + base_url=base_url, + http_client=http_client, + ) + else: + self._responses_client = self.client + # Prefix cache configuration # - force_prefix_cache: enable structured content with cache_control for explicit caching # - prefix_cache_roles: which messages to cache (only these are converted to structured format) @@ -176,17 +226,21 @@ def generate(self, Union[Message, Generator[Message, None, None]]: Either a single Message object (non-streaming) or a generator yielding Message chunks (streaming). """ - parameters = inspect.signature( - self.client.chat.completions.create).parameters args = self.args.copy() args.update(kwargs) stream = args.get('stream', False) + if self._use_responses_api: + if stream: + return self._responses_stream_generate(messages, tools, **args) + else: + return self._responses_generate(messages, tools, **args) + + parameters = inspect.signature( + self.client.chat.completions.create).parameters args = {key: value for key, value in args.items() if key in parameters} completion = self._call_llm(messages, self.format_tools(tools), **args) - # Complex task may produce long response - # Call continue_generate to keep generating if the finish_reason is `length` max_continue_runs = max_continue_runs or self.max_continue_runs if stream: return self._stream_continue_generate(messages, completion, tools, @@ -540,6 +594,371 @@ def _continue_generate(self, else: return new_message + def _build_responses_input( + self, messages: List[Message]) -> List[Dict[str, Any]]: + """Convert internal Message list to the ``input`` format expected by + the Responses API. + + Key differences from chat completions format: + - ``system`` role becomes ``developer``. + - ``assistant`` messages with ``tool_calls`` emit the text content + as a normal assistant item, followed by one ``function_call`` + item per tool call. + - ``tool`` role messages become ``function_call_output`` items + (keyed by ``call_id``, not ``role``). + """ + items: List[Dict[str, Any]] = [] + for msg in messages: + if msg.role == 'system': + items.append({ + 'role': 'developer', + 'content': msg.content, + }) + elif msg.role == 'assistant': + if self._responses_state_mode != 'previous_response_id': + # Stateless mode needs explicit passback of opaque reasoning + # items returned by the previous response. + for raw_item in getattr(msg, '_responses_output_items', + []): + items.append(raw_item) + if msg.content and not self._is_responses_tool_placeholder( + msg): + items.append({ + 'role': 'assistant', + 'content': msg.content, + }) + if msg.tool_calls: + for tc in msg.tool_calls: + arguments = tc.get('arguments', '{}') + if not isinstance(arguments, str): + arguments = json.dumps( + arguments, ensure_ascii=False) + items.append({ + 'type': 'function_call', + 'call_id': tc.get('id', ''), + 'name': tc.get('tool_name', ''), + 'arguments': arguments, + }) + elif msg.role == 'tool': + content = msg.content + if not isinstance(content, str): + content = json.dumps(content, ensure_ascii=False) + items.append({ + 'type': 'function_call_output', + 'call_id': msg.tool_call_id or '', + 'output': content, + }) + else: + items.append({ + 'role': msg.role, + 'content': msg.content, + }) + return items + + @staticmethod + def _is_responses_tool_placeholder(message: Message) -> bool: + """Return True for framework-generated assistant placeholder text.""" + return bool(message.tool_calls + ) and message.content == 'Let me do a tool calling.' + + def _prepare_responses_request( + self, messages: List[Message], + args: Dict[str, Any]) -> tuple[List[Message], Dict[str, Any]]: + """Prepare message slice and request args for Responses API calls.""" + request_args = dict(args) + + if self._responses_state_mode != 'previous_response_id': + return messages, request_args + + if request_args.get('previous_response_id'): + return messages, request_args + + for idx in range(len(messages) - 1, -1, -1): + msg = messages[idx] + if msg.role == 'assistant' and msg.id: + request_args['previous_response_id'] = msg.id + return messages[idx + 1:], request_args + + return messages, request_args + + def _build_responses_tools( + self, + tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: + """Convert internal Tool list to Responses API function tool format.""" + if not tools: + return None + return [{ + 'type': 'function', + 'name': t['tool_name'], + 'description': t.get('description', ''), + 'parameters': t.get('parameters', {}), + } for t in tools] + + def _build_responses_kwargs(self, args: Dict) -> Dict: + """Filter and reshape generation args for ``responses.create``.""" + kwargs: Dict[str, Any] = {} + + reasoning_effort = args.get('reasoning_effort') + reasoning_summary = args.get('reasoning_summary', 'auto') + if reasoning_effort or reasoning_summary: + reasoning: Dict[str, Any] = {} + if reasoning_effort: + reasoning['effort'] = reasoning_effort + if reasoning_summary: + reasoning['summary'] = reasoning_summary + kwargs['reasoning'] = reasoning + + if args.get('temperature') is not None: + kwargs['temperature'] = args['temperature'] + if args.get('top_p') is not None: + kwargs['top_p'] = args['top_p'] + if args.get('max_output_tokens') is not None: + kwargs['max_output_tokens'] = args['max_output_tokens'] + if args.get('stream_options') is not None: + kwargs['stream_options'] = args['stream_options'] + if args.get('previous_response_id') is not None: + kwargs['previous_response_id'] = args['previous_response_id'] + + include = args.get('include') + if include is not None: + kwargs['include'] = include + elif self._responses_state_mode != 'previous_response_id': + # Stateless multi-turn mode needs encrypted reasoning so opaque + # reasoning items can be passed back in subsequent requests. + kwargs['include'] = ['reasoning.encrypted_content'] + + return kwargs + + @staticmethod + def _extract_reasoning_summaries_from_response(response) -> str: + """Pull reasoning summary text from a completed Responses API object.""" + parts: List[str] = [] + for item in getattr(response, 'output', []) or []: + if getattr(item, 'type', None) == 'reasoning': + for summary in getattr(item, 'summary', []) or []: + text = getattr(summary, 'text', None) + if text: + parts.append(text) + return '\n'.join(parts) + + @staticmethod + def _extract_tool_calls_from_response( + response) -> Optional[List[ToolCall]]: + """Extract tool calls from a completed Responses API object.""" + tool_calls: List[ToolCall] = [] + for item in getattr(response, 'output', []) or []: + if getattr(item, 'type', None) == 'function_call': + arguments = getattr(item, 'arguments', '{}') + if not isinstance(arguments, str): + arguments = json.dumps(arguments, ensure_ascii=False) + tool_calls.append( + ToolCall( + id=getattr(item, 'call_id', '') + or getattr(item, 'id', ''), + index=len(tool_calls), + type='function', + tool_name=getattr(item, 'name', ''), + arguments=arguments, + )) + return tool_calls if tool_calls else None + + @staticmethod + def _extract_usage_from_response(response) -> tuple: + """Return (prompt_tokens, completion_tokens) from a Responses API object.""" + usage = getattr(response, 'usage', None) + if usage is None: + return 0, 0 + return ( + getattr(usage, 'input_tokens', 0) or 0, + getattr(usage, 'output_tokens', 0) or 0, + ) + + @staticmethod + def _to_jsonable(value: Any) -> Any: + """Convert SDK objects nested in Responses items into JSON-safe data.""" + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, list): + return [OpenAI._to_jsonable(item) for item in value] + if isinstance(value, dict): + return { + key: OpenAI._to_jsonable(item) + for key, item in value.items() + } + if hasattr(value, 'model_dump'): + return OpenAI._to_jsonable(value.model_dump()) + if hasattr(value, 'to_dict'): + return OpenAI._to_jsonable(value.to_dict()) + return value + + def _collect_passback_items(self, response) -> List[Dict[str, Any]]: + """Collect output items that must be passed back in multi-turn calls. + + Per OpenAI docs, reasoning items returned alongside tool calls must be + included in the next request for reasoning models. + """ + items: List[Dict[str, Any]] = [] + for item in getattr(response, 'output', []) or []: + item_type = getattr(item, 'type', None) + if item_type == 'reasoning': + passback_item: Dict[str, Any] = { + 'type': + 'reasoning', + 'summary': + self._to_jsonable(getattr(item, 'summary', []) or []), + } + encrypted_content = getattr(item, 'encrypted_content', None) + if encrypted_content: + passback_item['encrypted_content'] = encrypted_content + if not getattr(self, '_is_dashscope', False): + item_id = getattr(item, 'id', None) + if item_id: + passback_item['id'] = item_id + items.append(passback_item) + return items + + def _responses_generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Message: + """Non-streaming Responses API call.""" + request_messages, request_args = self._prepare_responses_request( + messages, args) + input_items = self._build_responses_input(request_messages) + resp_tools = self._build_responses_tools(tools) + kwargs = self._build_responses_kwargs(request_args) + if resp_tools: + kwargs['tools'] = resp_tools + + response = self._responses_client.responses.create( + model=self.model, + input=input_items, + **kwargs, + ) + text = getattr(response, 'output_text', '') or '' + reasoning = self._extract_reasoning_summaries_from_response(response) + resp_tool_calls = self._extract_tool_calls_from_response(response) + prompt_tokens, completion_tokens = self._extract_usage_from_response( + response) + passback = self._collect_passback_items(response) + + return Message( + role='assistant', + content=text, + reasoning_content=reasoning, + tool_calls=resp_tool_calls, + _responses_output_items=passback, + id=getattr(response, 'id', ''), + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + @staticmethod + def _extract_reasoning_from_item(item) -> str: + """Extract reasoning summary text from a single output item.""" + parts: List[str] = [] + for summary in getattr(item, 'summary', []) or []: + text = getattr(summary, 'text', None) + if text: + parts.append(text) + return '\n'.join(parts) + + def _responses_stream_generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + **args) -> Generator[Message, None, None]: + """Streaming Responses API call. + + Yields incremental ``Message`` objects. Reasoning summaries are + extracted from ``response.output_item.done`` events (type=reasoning) + which arrive *before* the first text delta, so the agent layer can + display the thinking header before content begins streaming. + """ + request_messages, request_args = self._prepare_responses_request( + messages, args) + input_items = self._build_responses_input(request_messages) + resp_tools = self._build_responses_tools(tools) + kwargs = self._build_responses_kwargs(request_args) + if resp_tools: + kwargs['tools'] = resp_tools + + stream = self._responses_client.responses.create( + model=self.model, + input=input_items, + stream=True, + **kwargs, + ) + + current_message = Message( + role='assistant', + content='', + reasoning_content='', + ) + streamed_text = '' + final_response = None + response_error_msg = '' + reasoning_parts: List[str] = [] + + for event in stream: + event_type = getattr(event, 'type', '') + + if event_type == 'response.output_item.done': + item = getattr(event, 'item', None) + if item and getattr(item, 'type', None) == 'reasoning': + summary_text = self._extract_reasoning_from_item(item) + if summary_text: + reasoning_parts.append(summary_text) + current_message.reasoning_content = '\n'.join( + reasoning_parts) + yield current_message + + elif event_type == 'response.output_text.delta': + delta = getattr(event, 'delta', '') + if delta: + streamed_text += delta + current_message.content = streamed_text + yield current_message + + elif event_type == 'response.output_text.done': + done_text = getattr(event, 'text', '') + if done_text and not streamed_text: + streamed_text = done_text + current_message.content = streamed_text + yield current_message + + elif event_type == 'response.completed': + final_response = getattr(event, 'response', None) + + elif event_type == 'response.failed': + failed_response = getattr(event, 'response', None) + failed_error = getattr(failed_response, 'error', None) + response_error_msg = getattr(failed_error, 'message', + '') or str(failed_error) + + if final_response: + if not reasoning_parts: + reasoning = self._extract_reasoning_summaries_from_response( + final_response) + if reasoning: + current_message.reasoning_content = reasoning + resp_tool_calls = self._extract_tool_calls_from_response( + final_response) + if resp_tool_calls: + current_message.tool_calls = resp_tool_calls + passback = self._collect_passback_items(final_response) + if passback: + current_message._responses_output_items = passback + prompt_tokens, completion_tokens = self._extract_usage_from_response( + final_response) + current_message.prompt_tokens = prompt_tokens + current_message.completion_tokens = completion_tokens + current_message.id = getattr(final_response, 'id', '') + yield current_message + elif response_error_msg: + logger.error(f'Responses API failed: {response_error_msg}') + raise RuntimeError( + f'Responses API call failed: {response_error_msg}') + def _format_input_message(self, messages: List[Message]) -> List[Dict[str, Any]]: """Converts a list of Message objects into the format expected by the OpenAI API. diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 410aa12f0..12f66af24 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -1,8 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json from dataclasses import asdict, dataclass, field from typing import Any, Dict, List, Optional, Union - -import json from typing_extensions import Literal, Required, TypedDict @@ -40,6 +39,10 @@ class Message: # needed for output reasoning_content: str = '' + # Opaque output items from the Responses API that must be passed back + # in multi-turn tool-calling conversations (e.g. reasoning items). + _responses_output_items: List[Dict[str, Any]] = field(default_factory=list) + # request id id: str = '' @@ -61,11 +64,8 @@ class Message: api_calls: int = 1 - # Knowledge search (sirchmunk) related fields - # searching_detail: Search process logs and metadata for frontend display - searching_detail: Dict[str, Any] = field(default_factory=dict) - # search_result: Raw search results to be used as context for next LLM turn - search_result: List[Dict[str, Any]] = field(default_factory=list) + # role=tool: extra payload for UIs / SSE only; omitted from LLM API via to_dict_clean(). + tool_detail: Optional[str] = None def to_dict(self): return asdict(self) @@ -88,7 +88,16 @@ def to_dict_clean(self): } } required = ['content', 'role'] - rm = ['completion_tokens', 'prompt_tokens', 'api_calls'] + # Never send UI-only fields to model providers. + rm = [ + 'completion_tokens', + 'prompt_tokens', + 'api_calls', + 'tool_detail', + 'searching_detail', + 'search_result', + '_responses_output_items', + ] return { key: value for key, value in raw_dict.items() @@ -98,20 +107,33 @@ def to_dict_clean(self): @dataclass class ToolResult: + """Tool execution outcome. + + ``text`` is sent to the model as the tool message ``content``. + ``tool_detail`` is optional verbose output for frontends only (SSE, logs). + """ + text: str resources: List[str] = field(default_factory=list) extra: dict = field(default_factory=dict) + tool_detail: Optional[str] = None @staticmethod def from_raw(raw): if isinstance(raw, str): return ToolResult(text=raw) if isinstance(raw, dict): + model_text = raw.get('result') + if model_text is None: + model_text = raw.get('text', '') + td = raw.get('tool_detail') return ToolResult( - text=str(raw.get('text', '')), + text=str(model_text), resources=raw.get('resources', []), + tool_detail=None if td is None else str(td), extra={ k: v - for k, v in raw.items() if k not in ['text', 'resources'] + for k, v in raw.items() + if k not in ['text', 'resources', 'result', 'tool_detail'] }) raise TypeError('tool_call_result must be str or dict') diff --git a/ms_agent/memory/base.py b/ms_agent/memory/base.py index fde42fb56..72baec529 100644 --- a/ms_agent/memory/base.py +++ b/ms_agent/memory/base.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import ABC, abstractmethod +from omegaconf import DictConfig from typing import List from ms_agent.llm.utils import Message from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR -from omegaconf import DictConfig class Memory(ABC): diff --git a/ms_agent/memory/condenser/code_condenser.py b/ms_agent/memory/condenser/code_condenser.py index 12fe626e5..723b5eebe 100644 --- a/ms_agent/memory/condenser/code_condenser.py +++ b/ms_agent/memory/condenser/code_condenser.py @@ -1,7 +1,7 @@ +import json import os from typing import List -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils import get_logger diff --git a/ms_agent/memory/condenser/context_compressor.py b/ms_agent/memory/condenser/context_compressor.py index 9bec9bcf1..08ef02755 100644 --- a/ms_agent/memory/condenser/context_compressor.py +++ b/ms_agent/memory/condenser/context_compressor.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json from typing import List, Optional -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory from ms_agent.utils.logger import logger diff --git a/ms_agent/memory/condenser/refine_condenser.py b/ms_agent/memory/condenser/refine_condenser.py index 557e2a1ff..6d7ec804e 100644 --- a/ms_agent/memory/condenser/refine_condenser.py +++ b/ms_agent/memory/condenser/refine_condenser.py @@ -1,6 +1,6 @@ +import json from typing import List -import json from ms_agent.llm import LLM, Message from ms_agent.memory import Memory diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index b087a2dd9..d84826de4 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -2,23 +2,23 @@ import asyncio import hashlib import importlib +import json +import json5 import os import re import traceback from datetime import datetime from functools import partial, wraps from inspect import signature +from omegaconf import DictConfig, OmegaConf from typing import Any, Dict, List, Optional, Tuple -import json -import json5 from ms_agent.llm.utils import Message from ms_agent.memory import Memory from ms_agent.utils import get_fact_retrieval_prompt from ms_agent.utils.constants import (DEFAULT_OUTPUT_DIR, DEFAULT_SEARCH_LIMIT, DEFAULT_USER, get_service_config) from ms_agent.utils.logger import logger -from omegaconf import DictConfig, OmegaConf class MemoryMapping: diff --git a/ms_agent/memory/diversity.py b/ms_agent/memory/diversity.py index 0da0be038..2e307a0ac 100644 --- a/ms_agent/memory/diversity.py +++ b/ms_agent/memory/diversity.py @@ -1,12 +1,11 @@ +import asyncio import re from copy import deepcopy +from omegaconf import DictConfig from typing import List from ms_agent.utils import get_logger -from omegaconf import DictConfig - from ..llm import LLM, Message -from ..tools import SplitTask from .base import Memory logger = get_logger() @@ -58,7 +57,6 @@ class Diversity(Memory): def __init__(self, config): super().__init__(config) self.llm = None - self.split_task = None self.num_split = 5 self.memory_called = False _config = deepcopy(config) @@ -67,9 +65,43 @@ def __init__(self, config): delattr(_config, 'tools') _config.generation_config.temperature = 1.0 self.llm = LLM.from_config(_config) - self.split_task = SplitTask(_config, tag_prefix='diversity-') + self._sub_config = _config self.num_split = getattr(config, 'num_split', self.num_split) + async def _run_tasks_sequential(self, tasks: list) -> str: + """Run a list of {system, query} tasks sequentially using LLMAgent.""" + from ms_agent.agent import LLMAgent + res = [] + for i, task in enumerate(tasks): + system = task.get('system', '') + query = task.get('query', '') + cfg = deepcopy(self._sub_config) + if not hasattr(cfg, 'prompt'): + cfg.prompt = DictConfig({}) + cfg.prompt.system = system + agent = LLMAgent( + config=cfg, + trust_remote_code=getattr(cfg, 'trust_remote_code', False), + tag=f'{getattr(cfg, "tag", "agent")}-diversity-{i}', + load_cache=False, + ) + try: + messages = await agent.run(query) + if isinstance(messages, list) and messages: + content = messages[-1].content or '' + else: + content = str(messages) + except Exception as e: + content = f'SubTask{i} failed with error: {e}' + res.append(content) + + formatted = '' + for i, content in enumerate(res): + if len(content) > 2048: + content = content[:2048] + formatted += f'SplitTask{i}:{content}\n' + return formatted + async def run(self, messages: List[Message]): if self.memory_called: return messages @@ -93,13 +125,7 @@ async def run(self, messages: List[Message]): } arguments.append(inputs) - arguments = { - 'tasks': arguments, - 'execution_mode': 'sequential', - } - - results = await self.split_task.call_tool( - '', tool_name='', tool_args=arguments) + results = await self._run_tasks_sequential(arguments) pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): @@ -118,13 +144,7 @@ async def run(self, messages: List[Message]): } arguments.append(inputs) - arguments = { - 'tasks': arguments, - 'execution_mode': 'sequential', - } - - results = await self.split_task.call_tool( - '', tool_name='', tool_args=arguments) + results = await self._run_tasks_sequential(arguments) pattern = r'(.*?)' all_keywords = [] for keywords in re.findall(pattern, results, re.DOTALL): diff --git a/ms_agent/memory/memory_manager.py b/ms_agent/memory/memory_manager.py index 5a203505d..5d0a10210 100644 --- a/ms_agent/memory/memory_manager.py +++ b/ms_agent/memory/memory_manager.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import Dict from ms_agent.memory import Memory, memory_mapping from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/prompting/file_resolver.py b/ms_agent/prompting/file_resolver.py index c04b8bdde..2bf340200 100644 --- a/ms_agent/prompting/file_resolver.py +++ b/ms_agent/prompting/file_resolver.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from dataclasses import dataclass -from typing import List, Optional, Tuple - from omegaconf import DictConfig +from typing import List, Optional, Tuple @dataclass(frozen=True) diff --git a/ms_agent/rag/extraction.py b/ms_agent/rag/extraction.py index c62b10067..0215d8b7b 100644 --- a/ms_agent/rag/extraction.py +++ b/ms_agent/rag/extraction.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable from abc import ABC, abstractmethod -from typing import Any, Dict, List - from docling_core.transforms.chunker import BaseChunk from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem, DocItemLabel +from typing import Any, Dict, List + from ms_agent.rag.schema import KeyInformation from ms_agent.tools.docling.chunker import HybridDocumentChunker from ms_agent.tools.docling.doc_loader import DocLoader diff --git a/ms_agent/rag/llama_index_rag.py b/ms_agent/rag/llama_index_rag.py index e4535d38d..129fda95f 100644 --- a/ms_agent/rag/llama_index_rag.py +++ b/ms_agent/rag/llama_index_rag.py @@ -1,11 +1,10 @@ import os import shutil -from typing import Any, List, Optional - -from ms_agent.utils import assert_package_exist from omegaconf import DictConfig +from typing import Any, List, Optional from modelscope import snapshot_download +from ms_agent.utils import assert_package_exist from ..llm import LLM, Message from .base import RAG @@ -41,6 +40,7 @@ def __init__(self, config: DictConfig): from llama_index.core import Settings from llama_index.core.node_parser import SentenceSplitter + # Set node parser Settings.node_parser = SentenceSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) @@ -49,10 +49,10 @@ def __init__(self, config: DictConfig): if self.retrieve_only: Settings.llm = None else: + from llama_index.core.base.llms.types import (CompletionResponse, + LLMMetadata) from llama_index.core.llms import CustomLLM - from llama_index.core.base.llms.types import LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback - from llama_index.core.base.llms.types import CompletionResponse self._llm_instance = LLM.from_config(self.config) class MSCustomLLM(CustomLLM): @@ -108,7 +108,7 @@ def _validate_config(self, config: DictConfig): raise ValueError('chunk_size must be greater than 0') def _setup_embedding_model(self, config: DictConfig): - from llama_index.core import (Settings) + from llama_index.core import Settings from llama_index.embeddings.huggingface import HuggingFaceEmbedding try: use_hf = getattr(config, 'use_huggingface', False) @@ -124,7 +124,7 @@ def _setup_embedding_model(self, config: DictConfig): async def add_documents(self, documents: List[str]): if not documents: raise ValueError('Document list cannot be empty') - from llama_index.core import (Document, VectorStoreIndex) + from llama_index.core import Document, VectorStoreIndex docs = [Document(text=doc) for doc in documents] self.index = VectorStoreIndex.from_documents(docs) if not self.retrieve_only: @@ -159,6 +159,7 @@ async def _setup_query_engine(self): return from llama_index.core import Settings + # Check if LLM is set if Settings.llm is None and not self.retrieve_only: return @@ -239,10 +240,11 @@ async def hybrid_search(self, query: str, top_k: int = 5) -> List[dict]: return [] from llama_index.core.retrievers import VectorIndexRetriever + # Try to import BM25 related modules try: - from llama_index.retrievers.bm25 import BM25Retriever from llama_index.core.retrievers import QueryFusionRetriever + from llama_index.retrievers.bm25 import BM25Retriever bm25_available = True except ImportError: bm25_available = False @@ -316,7 +318,7 @@ async def load_index(self, persist_dir: Optional[str] = None): raise FileNotFoundError( f'Index directory does not exist: {load_dir}') - from llama_index.core import (StorageContext, load_index_from_storage) + from llama_index.core import StorageContext, load_index_from_storage storage_context = StorageContext.from_defaults(persist_dir=load_dir) self.index = load_index_from_storage(storage_context) diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index e66da954d..cf52aaa7c 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -5,5 +5,5 @@ 'LlamaIndexRAG': LlamaIndexRAG, } -# Note: SirchmunkSearch is registered in knowledge_search module -# and integrated directly in LLMAgent, not through rag_mapping +# Note: Sirchmunk local search is the ``localsearch`` tool +# (ms_agent.tools.search); it is not wired through rag_mapping. diff --git a/ms_agent/retriever/hybrid_retriever.py b/ms_agent/retriever/hybrid_retriever.py index e84bc8398..d97c4fcdc 100644 --- a/ms_agent/retriever/hybrid_retriever.py +++ b/ms_agent/retriever/hybrid_retriever.py @@ -1,11 +1,11 @@ import asyncio +import faiss import math +import numpy as np import os import threading from typing import List, Tuple -import faiss -import numpy as np from ms_agent.utils.tokenizer_util import TokenizerUtil os.environ['OMP_NUM_THREADS'] = '1' diff --git a/ms_agent/sandbox/sandbox.py b/ms_agent/sandbox/sandbox.py index 8a5753761..3f63bf962 100644 --- a/ms_agent/sandbox/sandbox.py +++ b/ms_agent/sandbox/sandbox.py @@ -46,7 +46,7 @@ def __init__(self, **kwargs): super().__init__() self._init() - from ms_enclave.sandbox import SandboxConfig, DockerSandboxConfig + from ms_enclave.sandbox import DockerSandboxConfig, SandboxConfig # Mount host directories into the sandbox container if provided _volumes = kwargs.pop('volumes', None) or [] diff --git a/ms_agent/skill/loader.py b/ms_agent/skill/loader.py index 1f5dca2a7..f6d898730 100644 --- a/ms_agent/skill/loader.py +++ b/ms_agent/skill/loader.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Union from ms_agent.utils.logger import logger - from .schema import SkillSchema, SkillSchemaParser diff --git a/ms_agent/skill/schema.py b/ms_agent/skill/schema.py index 722e0acc4..71a50d01b 100644 --- a/ms_agent/skill/schema.py +++ b/ms_agent/skill/schema.py @@ -6,13 +6,12 @@ Each Skill is represented as a self-contained directory with metadata. """ import re +import yaml from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union -import yaml from ms_agent.utils.logger import logger - from .spec import Spec SUPPORTED_SCRIPT_EXT = ('.py', '.sh', '.js') diff --git a/ms_agent/tools/__init__.py b/ms_agent/tools/__init__.py index 58b26ecee..2dae662c7 100644 --- a/ms_agent/tools/__init__.py +++ b/ms_agent/tools/__init__.py @@ -4,6 +4,6 @@ from .code_server import LSPCodeServer from .filesystem_tool import FileSystemTool from .mcp_client import MCPClient -from .split_task import SplitTask +from .task_control_tool import TaskControlTool from .todolist_tool import TodoListTool from .tool_manager import ToolManager diff --git a/ms_agent/tools/agent_tool.py b/ms_agent/tools/agent_tool.py index 5c86e18c5..554c6a6d6 100644 --- a/ms_agent/tools/agent_tool.py +++ b/ms_agent/tools/agent_tool.py @@ -1,18 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import multiprocessing as mp import os import threading import traceback import uuid from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from omegaconf import DictConfig, ListConfig, OmegaConf from queue import Empty as QueueEmpty from queue import Full as QueueFull from typing import Any, Callable, Dict, List, Optional, Union -import json from ms_agent.agent.loader import AgentLoader from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase @@ -20,8 +20,7 @@ from ms_agent.utils.stats import (append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage) -from ms_agent.utils.thread_util import DaemonThreadPoolExecutor -from omegaconf import DictConfig, ListConfig, OmegaConf +from ms_agent.utils.stream_writer import SubAgentStreamWriter logger = get_logger() @@ -52,6 +51,11 @@ class _AgentToolSpec: env: Optional[Dict[str, str]] run_in_thread: bool run_in_process: bool + dynamic: bool = False + disallowed_tools: Optional[List[str]] = None + max_subtask_output_chars: int = 8192 + run_in_background: bool = False + sync_timeout_s: Optional[float] = None _MESSAGE_FIELDS = set(Message.__dataclass_fields__.keys()) @@ -71,9 +75,17 @@ def _message_from_data(data: Any) -> Message: def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): if spec.inline_config is not None: - config_override = OmegaConf.create(spec.inline_config) + container = _to_container(spec.inline_config) + base_override = OmegaConf.create(container) if isinstance( + container, dict) else OmegaConf.create({}) else: - config_override = None + base_override = OmegaConf.create({}) + # Sub-agents default snapshots off in LLMAgent unless enable_snapshots is set + # on the merged agent config (e.g. in the sub-agent YAML). + config_override = OmegaConf.merge( + base_override, + OmegaConf.create({'ms_agent_subagent': True}), + ) trust_remote_code = spec.trust_remote_code if trust_remote_code is None: @@ -90,6 +102,16 @@ def _build_sub_agent(spec: _AgentToolSpec, default_trust_remote_code: bool): generation_cfg = getattr(agent.config, 'generation_config', DictConfig({})) agent.config.generation_config = generation_cfg + + if spec.disallowed_tools: + tools_cfg = getattr(agent.config, 'tools', None) + if tools_cfg is not None: + tools_dict = OmegaConf.to_container(tools_cfg, resolve=True) + if isinstance(tools_dict, dict): + for key in spec.disallowed_tools: + tools_dict.pop(key, None) + agent.config.tools = OmegaConf.create(tools_dict) + return agent @@ -184,33 +206,86 @@ def __init__(self, config: DictConfig, **kwargs): self._enable_stats = False self._specs: Dict[str, _AgentToolSpec] = {} self._server_tools: Dict[str, List[Tool]] = {} - self._thread_executor: Optional[ThreadPoolExecutor] = None - self._thread_max_workers: int = 0 self._chunk_cb: Optional[Callable[..., Any]] = None self._active_processes: Dict[str, mp.Process] = {} self._active_processes_lock = threading.Lock() + self._task_manager = None + self._watcher_tasks: set = set() + # call_id -> (asyncio.Task, spec, payload, escape_event) + self._active_sync_tasks: Dict[str, Any] = {} + # effective_call_id -> stream file path (set during _run_agent, consumed by call_tool) + self._stream_paths: Dict[str, str] = {} self._load_specs() - self._init_thread_pool_config() - - def _init_thread_pool_config(self): - tools_cfg = getattr(self.config, 'tools', DictConfig({})) - agent_tools_cfg = getattr(tools_cfg, 'agent_tools', DictConfig({})) - max_workers = getattr(agent_tools_cfg, 'max_workers', None) - if max_workers is None: - max_workers = os.getenv('AGENT_TOOL_MAX_WORKERS', None) - try: - self._thread_max_workers = int(max_workers) if max_workers else 3 - except Exception: - self._thread_max_workers = 3 @property def enabled(self) -> bool: return bool(self._specs) + _SPLIT_TASK_DESCRIPTION = ( + 'Split complex task into sub tasks and start them, for example, ' + 'split a website generation task into sub tasks, ' + 'you plan the framework, include code files and classes and functions, and give the detail ' + 'information to the system and query field of the subtask, then ' + 'let each subtask to write a single file') + + _SPLIT_TASK_PARAMETERS = { + 'type': 'object', + 'properties': { + 'tasks': { + 'type': + 'array', + 'description': + ('MANDATORY: Each element is a dict, which must contains two fields: ' + '`system`(str) and `query`(str) to start one sub task.'), + }, + 'execution_mode': { + 'type': + 'string', + 'enum': ['sequential', 'parallel'], + 'description': + 'Whether to run sub-tasks sequentially or in parallel.', + }, + }, + 'required': ['tasks'], + 'additionalProperties': False, + } + def _load_specs(self): tools_cfg = getattr(self.config, 'tools', DictConfig({})) + + # Backward compat: if config.tools.split_task exists, register a built-in dynamic spec + if hasattr(tools_cfg, 'split_task'): + split_cfg = tools_cfg.split_task + tag_prefix = getattr(split_cfg, 'tag_prefix', 'worker-') + run_in_thread = bool(getattr(split_cfg, 'run_in_thread', True)) + run_in_process = bool( + getattr(split_cfg, 'run_in_process', run_in_thread)) + builtin_spec = _AgentToolSpec( + tool_name='split_to_sub_task', + description=self._SPLIT_TASK_DESCRIPTION, + parameters=self._SPLIT_TASK_PARAMETERS, + config_path=None, + inline_config=None, + server_name='split_task', + tag_prefix=tag_prefix, + input_mode='text', + request_field='request', + input_template=None, + output_mode='final_message', + max_output_chars=100000, + trust_remote_code=None, + env=None, + run_in_thread=run_in_thread, + run_in_process=run_in_process, + dynamic=True, + max_subtask_output_chars=8192, + ) + self._specs['split_to_sub_task'] = builtin_spec + agent_tools_cfg = getattr(tools_cfg, 'agent_tools', None) if agent_tools_cfg is None: + if self._specs: + self._build_server_index() return if isinstance(agent_tools_cfg, DictConfig) and hasattr( @@ -233,6 +308,8 @@ def _load_specs(self): definitions_list = definitions else: logger.warning('agent_tools configuration is not iterable; skip.') + if self._specs: + self._build_server_index() return for idx, spec_cfg in enumerate(definitions_list): @@ -258,6 +335,9 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], 'agent_tools[%s] missing tool_name/name field, skip.', idx) return None + mode = getattr(cfg, 'mode', None) + is_dynamic = (mode == 'dynamic') + agent_cfg = getattr(cfg, 'agent', None) config_path = getattr(cfg, 'config_path', None) inline_cfg = getattr(cfg, 'config', None) @@ -267,7 +347,7 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], inline_cfg = _to_container( inline_cfg) if inline_cfg is not None else None - if not config_path and inline_cfg is None: + if not is_dynamic and not config_path and inline_cfg is None: logger.warning( 'agent_tools[%s] (%s) missing config_path/config definition.', idx, tool_name) @@ -277,19 +357,22 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], f'Invoke agent "{tool_name}" as a tool.') parameters = getattr(cfg, 'parameters', None) if parameters is None: - parameters = { - 'type': 'object', - 'properties': { - 'request': { - 'type': - 'string', - 'description': - f'Task description forwarded to the sub-agent {tool_name}.' + if is_dynamic: + parameters = self._SPLIT_TASK_PARAMETERS + else: + parameters = { + 'type': 'object', + 'properties': { + 'request': { + 'type': + 'string', + 'description': + f'Task description forwarded to the sub-agent {tool_name}.' + }, }, - }, - 'required': ['request'], - 'additionalProperties': True, - } + 'required': ['request'], + 'additionalProperties': True, + } else: parameters = _to_container(parameters) @@ -302,17 +385,23 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], input_mode = getattr(cfg, 'input_mode', 'text') output_mode = getattr(cfg, 'output_mode', 'final_message') max_chars = int(getattr(cfg, 'max_output_chars', 100000)) + max_subtask_chars = int(getattr(cfg, 'max_subtask_output_chars', 8192)) server_name = getattr(cfg, 'server_name', default_server) trust_remote_code = getattr(cfg, 'trust_remote_code', None) - # Run sub-agent in a background thread to avoid blocking the main event loop - # when underlying LLM SDKs are synchronous. run_in_thread = bool(getattr(cfg, 'run_in_thread', True)) - # Run sub-agent in an isolated process so timed-out calls can be killed. run_in_process = bool(getattr(cfg, 'run_in_process', run_in_thread)) env_cfg = getattr(cfg, 'env', None) env_cfg = _to_container(env_cfg) if env_cfg is not None else None + disallowed_raw = getattr(cfg, 'disallowed_tools', None) + disallowed_tools = _to_container( + disallowed_raw) if disallowed_raw is not None else None + if isinstance(disallowed_tools, list): + disallowed_tools = [str(t) for t in disallowed_tools] + elif disallowed_tools is not None: + disallowed_tools = None + if config_path and not os.path.isabs(config_path): base_dir = getattr(self.config, 'local_dir', None) if base_dir: @@ -336,6 +425,11 @@ def _build_spec(self, cfg: Union[DictConfig, Dict[str, Any]], env=env_cfg, run_in_thread=run_in_thread, run_in_process=run_in_process, + dynamic=is_dynamic, + disallowed_tools=disallowed_tools, + max_subtask_output_chars=max_subtask_chars, + run_in_background=bool(getattr(cfg, 'run_in_background', False)), + sync_timeout_s=float(getattr(cfg, 'sync_timeout_s', 0)) or None, ) def _build_server_index(self): @@ -351,29 +445,13 @@ def _build_server_index(self): self._server_tools = server_map async def connect(self): - # Lazily initialize a dedicated pool for agent tools that opt into - # `run_in_thread`, so we don't consume threads when the tool is unused. - if self._thread_executor is None: - # Use daemon threads to avoid blocking process exit when sub-agent - # calls are cancelled by tool-call timeouts. - self._thread_executor = DaemonThreadPoolExecutor( - max_workers=self._thread_max_workers, - thread_name_prefix='agent_tool_', - ) return None async def cleanup(self): self._terminate_all_active_processes(reason='during AgentTool cleanup') - if self._thread_executor is not None: - try: - try: - self._thread_executor.shutdown( - wait=False, cancel_futures=True) - except TypeError: - self._thread_executor.shutdown(wait=False) - except Exception: - pass - self._thread_executor = None + for t in list(self._watcher_tasks): + t.cancel() + self._watcher_tasks.clear() return None async def get_tools(self) -> Dict[str, Any]: @@ -395,6 +473,55 @@ def _emit_chunk_event(self, event_type: str, data: Dict[str, Any]) -> None: except Exception as exc: # noqa logger.warning(f'AgentTool chunk callback failed: {exc}') + def set_task_manager(self, task_manager) -> None: + self._task_manager = task_manager + + # ── stream-file helpers ──────────────────────────────────────────────── + + def _stream_file_enabled(self) -> bool: + """Return True if sub-agent stream-file writing is enabled. + + Checks in order: + 1. ``config.tools.agent_tools.enable_stream_file`` + 2. ``config.agent_stream_file`` + Defaults to ``False``. + """ + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) + if agent_tools_cfg is not None: + val = getattr(agent_tools_cfg, 'enable_stream_file', None) + if val is not None: + return bool(val) + return bool(getattr(self.config, 'agent_stream_file', False)) + + def _stream_file_dir(self) -> str: + """Return the output directory used for stream files. + + Checks ``config.tools.agent_tools.stream_file_dir`` first, then falls + back to ``config.output_dir``. + """ + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) + if agent_tools_cfg is not None: + override = getattr(agent_tools_cfg, 'stream_file_dir', None) + if override: + return str(override) + return str(getattr(self.config, 'output_dir', './output')) + + def _stream_include_in_result(self) -> bool: + """Return True if the stream-file path should be appended to the tool result. + + Controlled by ``config.tools.agent_tools.stream_include_in_result`` + (defaults to ``True`` when stream files are enabled). + """ + agent_tools_cfg = getattr( + getattr(self.config, 'tools', None), 'agent_tools', None) + if agent_tools_cfg is not None: + val = getattr(agent_tools_cfg, 'stream_include_in_result', None) + if val is not None: + return bool(val) + return True + async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: if tool_name not in self._specs: @@ -408,15 +535,288 @@ async def call_tool(self, server_name: str, *, tool_name: str, call_id = None if isinstance(tool_args, dict) and '__call_id' in tool_args: call_id = tool_args.pop('__call_id', None) + + # Generate a stable effective_call_id early so that _run_agent() can + # store the stream-file path under the same key and we can look it up. + effective_call_id = call_id or f'{tool_name}-{uuid.uuid4().hex[:8]}' + + if spec.dynamic: + return await self._call_dynamic(tool_args, spec) + payload = self._build_payload(tool_args, spec) + + if spec.run_in_background: + return await self._launch_background(payload, spec, + effective_call_id) + use_subprocess = spec.run_in_thread and spec.run_in_process - agent = None if use_subprocess else self._build_agent(spec) - messages = await self._run_agent(agent, payload, spec, call_id=call_id) - return self._format_output(messages, spec) + if use_subprocess: + messages = await self._run_agent( + None, payload, spec, call_id=effective_call_id) + result_str = self._format_output(messages, spec) + return self._maybe_append_stream_path(result_str, + effective_call_id) + + # Pure async/await with optional escape-to-background support. + result = await self._run_sync_escapable(payload, spec, + effective_call_id) + if isinstance(result, str): + # Already formatted: escaped to background, returns async_launched JSON. + return result + result_str = self._format_output(result, spec) + return self._maybe_append_stream_path(result_str, effective_call_id) + + def _maybe_append_stream_path(self, result_str: str, + effective_call_id: str) -> str: + """Append a human- and LLM-readable note about the step-by-step execution + log to *result_str* if streaming is enabled. + + The note is intentionally descriptive so that the parent agent understands + the file contains the sub-agent's incremental thinking/tool-call trace, + not the tool's output. + """ + if not self._stream_file_enabled(): + return result_str + if not self._stream_include_in_result(): + return result_str + path = self._stream_paths.pop(effective_call_id, None) + if path: + result_str += ( + f'\n\n[Note: The sub-agent\'s step-by-step execution trace ' + f'(messages, tool calls, intermediate reasoning) was streamed ' + f'incrementally to: {path}]') + return result_str def _build_agent(self, spec: _AgentToolSpec): return _build_sub_agent(spec, self._trust_remote_code) + async def _run_sync_escapable(self, payload: Any, spec: _AgentToolSpec, + call_id: Optional[str]) -> Any: + """Run sub-agent inline (pure async/await). + + If spec.sync_timeout_s is set, the call auto-escapes to background after + that many seconds. The caller can also trigger escape at any time via + escape_to_background(call_id). + + Returns either the raw messages list (normal completion) or a JSON string + (async_launched, when escaped to background). + """ + escape_event = asyncio.Event() + effective_call_id = call_id or uuid.uuid4().hex[:12] + + run_task = asyncio.create_task( + self._run_agent(None, payload, spec, call_id=effective_call_id)) + + self._active_sync_tasks[effective_call_id] = (run_task, spec, payload, + escape_event) + + try: + if spec.sync_timeout_s and spec.sync_timeout_s > 0: + escape_wait_task = asyncio.create_task(escape_event.wait()) + sleep_task = asyncio.create_task( + asyncio.sleep(spec.sync_timeout_s)) + _, pending = await asyncio.wait( + [run_task, escape_wait_task, sleep_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if not run_task.done(): + return await self._escape_running_task( + effective_call_id, run_task, spec, payload) + else: + # No timeout: wait for completion or explicit escape signal. + escape_task = asyncio.create_task(escape_event.wait()) + _, pending = await asyncio.wait( + [run_task, escape_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for t in pending: + t.cancel() + if not run_task.done(): + return await self._escape_running_task( + effective_call_id, run_task, spec, payload) + + return run_task.result() + finally: + self._active_sync_tasks.pop(effective_call_id, None) + + async def _escape_running_task(self, call_id: str, + run_task: 'asyncio.Task[Any]', + spec: _AgentToolSpec, payload: Any) -> str: + """Cancel the in-progress sync task and re-launch it as a background subprocess.""" + if self._task_manager is None: + raise RuntimeError( + f'AgentTool "{spec.tool_name}" tried to escape to background but ' + 'no TaskManager is attached.') + + run_task.cancel() + try: + await run_task + except (asyncio.CancelledError, Exception): + pass + + # Re-launch as background subprocess (same as _launch_background). + return await self._launch_background(payload, spec, call_id) + + def escape_to_background(self, call_id: str) -> bool: + """Signal a running sync call to escape to background. + + Called by TaskControlTool (or any external code) when the LLM or user + decides the task should continue in the background. + + Args: + call_id: The __call_id injected into tool_args, or the auto-generated + hex id returned in the async_launched JSON. + + Returns: + True if the escape signal was delivered, False if call_id not found. + """ + entry = self._active_sync_tasks.get(call_id) + if entry is None: + return False + _, _, _, escape_event = entry + escape_event.set() + return True + + async def _launch_background(self, payload: Any, spec: _AgentToolSpec, + call_id: Optional[str]) -> str: + """Fire-and-forget: start subprocess, register with TaskManager, return immediately.""" + if self._task_manager is None: + raise RuntimeError( + f'AgentTool "{spec.tool_name}" has run_in_background=true but ' + 'no TaskManager is attached. Ensure LLMAgent wires task_manager ' + 'into AgentTool via set_task_manager().') + + ctx = mp.get_context('spawn') + result_queue = ctx.Queue(maxsize=1) + process_payload = self._serialize_payload_for_process(payload) + proc = ctx.Process( + target=_run_agent_in_subprocess, + args=(spec, self._trust_remote_code, process_payload, False, None, + result_queue), + name=f'agent_tool_bg_{spec.tool_name}', + ) + proc.start() + + task_id = self._task_manager.register( + task_type='agent', + tool_name=spec.tool_name, + description=f'{spec.tool_name} (call_id={call_id})', + proc=proc, + ) + self._register_process(task_id, proc) + + async def _watcher(): + try: + result = await self._wait_process_result(proc, result_queue) + if result is None or not result.get('ok'): + err = (result + or {}).get('error', + 'subprocess exited without result') + tb = (result or {}).get('traceback', '') + if tb: + logger.warning(tb) + await self._task_manager.fail(task_id, err) + else: + result_payload = result.get('result', {}) or {} + restored = self._restore_process_result(result_payload) + output = self._format_output(restored, spec) + await self._task_manager.complete(task_id, output) + except Exception as exc: + await self._task_manager.fail(task_id, str(exc)) + finally: + self._unregister_process(task_id) + try: + result_queue.close() + result_queue.join_thread() + except Exception: + pass + + t = asyncio.create_task(_watcher()) + self._watcher_tasks.add(t) + t.add_done_callback(self._watcher_tasks.discard) + + return json.dumps( + { + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': spec.tool_name, + }, + ensure_ascii=False) + + async def _call_dynamic(self, tool_args: dict, + spec: '_AgentToolSpec') -> str: + tasks = tool_args.get('tasks', []) + execution_mode = tool_args.get('execution_mode', 'sequential') + + base_config = OmegaConf.to_container(self.config, resolve=True) + + async def _run_one(i: int, task: dict) -> str: + system = task.get('system', '') + query = task.get('query', '') + task_config = dict(base_config) if isinstance(base_config, + dict) else {} + # Avoid inheriting the parent agent's snapshot preference into each + # split sub-task; sub-agents use ms_agent_subagent defaults instead. + task_config.pop('enable_snapshots', None) + if 'prompt' not in task_config or not isinstance( + task_config.get('prompt'), dict): + task_config['prompt'] = {} + task_config['prompt']['system'] = system + sub_spec = _AgentToolSpec( + tool_name=f'{spec.tool_name}_sub{i}', + description='', + parameters={}, + config_path=None, + inline_config=task_config, + server_name=spec.server_name, + tag_prefix=spec.tag_prefix, + input_mode='text', + request_field='request', + input_template=None, + output_mode='final_message', + max_output_chars=spec.max_subtask_output_chars, + trust_remote_code=spec.trust_remote_code, + env=spec.env, + run_in_thread=spec.run_in_thread, + run_in_process=spec.run_in_process, + dynamic=False, + disallowed_tools=spec.disallowed_tools, + max_subtask_output_chars=spec.max_subtask_output_chars, + ) + use_subprocess = sub_spec.run_in_thread and sub_spec.run_in_process + agent = None if use_subprocess else self._build_agent(sub_spec) + messages = await self._run_agent(agent, query, sub_spec) + return self._format_output(messages, sub_spec) + + if execution_mode == 'parallel': + results = await asyncio.gather( + *[_run_one(i, task) for i, task in enumerate(tasks)], + return_exceptions=True, + ) + res_list = [] + for i, r in enumerate(results): + if isinstance(r, Exception): + res_list.append(f'SubTask{i} failed with error: {r}') + else: + res_list.append(str(r)) + else: + res_list = [] + for i, task in enumerate(tasks): + try: + r = await _run_one(i, task) + res_list.append(r) + except Exception as e: + res_list.append(f'SubTask{i} failed with error: {e}') + + formatted = '' + for i, content in enumerate(res_list): + if len(content) > spec.max_subtask_output_chars: + content = content[:spec.max_subtask_output_chars] + formatted += f'SubTask{i}:{content}\n' + return formatted + @staticmethod def _terminate_process(proc: Optional[mp.Process], *, reason: str) -> None: if proc is None: @@ -528,13 +928,30 @@ async def _run_agent(self, runtime_agent_tag = getattr(runtime_agent, 'tag', None) runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) + # ── stream-file writer (optional) ────────────────────────────────── + _writer: Optional[SubAgentStreamWriter] = None + if self._stream_file_enabled(): + _effective_call_id = call_id or f'{spec.tool_name}-{uuid.uuid4().hex[:8]}' + _writer = SubAgentStreamWriter( + output_dir=self._stream_file_dir(), + call_id=_effective_call_id, + tool_name=spec.tool_name, + ) + logger.info( + '[stream] %s (call_id=%s) streaming to %s', + spec.tool_name, + _effective_call_id, + _writer.stream_path, + ) + # ─────────────────────────────────────────────────────────────────── + async def _run_and_collect(): nonlocal runtime_agent, runtime_agent_tag, runtime_agent_type if runtime_agent is None: runtime_agent = self._build_agent(spec) runtime_agent_tag = getattr(runtime_agent, 'tag', None) runtime_agent_type = getattr(runtime_agent, 'AGENT_NAME', None) - if self._chunk_cb: + if self._chunk_cb or _writer is not None: result = await runtime_agent.run(payload, stream=True) else: result = await runtime_agent.run(payload) @@ -544,6 +961,8 @@ async def _run_and_collect(): 'call_id': call_id, 'tool_name': spec.tool_name, }) + if _writer is not None: + _writer.on_start(runtime_agent_tag) async for chunk in result: history = chunk self._emit_chunk_event( @@ -552,6 +971,8 @@ async def _run_and_collect(): 'tool_name': spec.tool_name, 'history': chunk, }) + if _writer is not None: + _writer.on_chunk(chunk) if history is not None: self._emit_chunk_event( 'end', { @@ -559,59 +980,59 @@ async def _run_and_collect(): 'tool_name': spec.tool_name, 'history': history, }) + if _writer is not None: + _writer.on_end(history) result = history else: self._emit_chunk_event('start', { 'call_id': call_id, 'tool_name': spec.tool_name, }) + if _writer is not None: + _writer.on_start(runtime_agent_tag) self._emit_chunk_event( 'chunk', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, }) + if _writer is not None: + _writer.on_chunk(result) self._emit_chunk_event( 'end', { 'call_id': call_id, 'tool_name': spec.tool_name, 'history': result, }) + if _writer is not None: + _writer.on_end(result) return result - async def _run_in_background(): - # Run sub-agent in a dedicated event loop in a background thread. - def _sync_runner(): - return asyncio.run(_run_and_collect()) - - loop = asyncio.get_running_loop() - if self._thread_executor is not None: - return await loop.run_in_executor(self._thread_executor, - _sync_runner) - return await asyncio.to_thread(_sync_runner) - async def _run_in_subprocess(): nonlocal runtime_agent_tag, runtime_agent_type ctx = mp.get_context('spawn') result_queue = ctx.Queue(maxsize=1) - event_queue = ctx.Queue( - maxsize=128) if self._chunk_cb is not None else None + # Create event_queue when either chunk_cb or stream writer is active + # so that sub-agent progress can be forwarded to both sinks. + need_events = self._chunk_cb is not None or _writer is not None + event_queue = ctx.Queue(maxsize=128) if need_events else None proc: Optional[mp.Process] = None run_id = f'{call_id or "agent_tool"}-{uuid.uuid4().hex[:8]}' def _emit_stream_event(event: Dict[str, Any]) -> None: - if not self._chunk_cb: - return history_payload = event.get('history') if not isinstance(history_payload, dict): return history = self._restore_process_result(history_payload) - self._emit_chunk_event( - 'chunk', { - 'call_id': call_id, - 'tool_name': spec.tool_name, - 'history': history, - }) + if self._chunk_cb: + self._emit_chunk_event( + 'chunk', { + 'call_id': call_id, + 'tool_name': spec.tool_name, + 'history': history, + }) + if _writer is not None: + _writer.on_chunk(history) try: if self._chunk_cb: @@ -619,12 +1040,14 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: 'call_id': call_id, 'tool_name': spec.tool_name, }) + if _writer is not None: + # agent_tag unknown until subprocess completes; pass None + _writer.on_start(None) process_payload = self._serialize_payload_for_process(payload) proc = ctx.Process( target=_run_agent_in_subprocess, args=(spec, self._trust_remote_code, process_payload, - self._chunk_cb - is not None, event_queue, result_queue), + need_events, event_queue, result_queue), name=f'agent_tool_{spec.tool_name}', ) proc.start() @@ -648,9 +1071,10 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: tb = result.get('traceback', '') if tb: logger.warning(tb) - raise RuntimeError( - f'Sub-agent {spec.tool_name} failed: {result.get("error", "unknown error")}' - ) + err_msg = f'Sub-agent {spec.tool_name} failed: {result.get("error", "unknown error")}' + if _writer is not None: + _writer.on_error(err_msg) + raise RuntimeError(err_msg) result_payload = result.get('result', {}) or {} runtime_agent_tag = result_payload.get( 'agent_tag') or runtime_agent_tag @@ -673,12 +1097,19 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: 'tool_name': spec.tool_name, 'history': restored, }) + # Always finalise the writer regardless of _chunk_cb. + if _writer is not None: + _writer.on_end(restored) return restored except asyncio.CancelledError: self._terminate_process(proc, reason='was cancelled') + if _writer is not None: + _writer.on_error('cancelled') raise - except Exception: + except Exception as exc: self._terminate_process(proc, reason='encountered error') + if _writer is not None: + _writer.on_error(str(exc)) raise finally: self._unregister_process(run_id) @@ -704,13 +1135,18 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: if spec.run_in_thread and spec.run_in_process: runner = _run_in_subprocess - elif spec.run_in_thread: - runner = _run_in_background else: runner = _run_and_collect if not self._enable_stats: - return await runner() + result = await runner() + if not spec.run_in_process: + self._save_transcript(result, runtime_agent_tag) + if _writer is not None: + # Store with the same key used by call_tool() to pop it. + store_key = call_id if call_id is not None else _effective_call_id + self._stream_paths[store_key] = _writer.stream_path + return result start_ts = now_iso() start_time = monotonic() @@ -718,6 +1154,11 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: result = None try: result = await runner() + if not spec.run_in_process: + self._save_transcript(result, runtime_agent_tag) + if _writer is not None: + store_key = call_id if call_id is not None else _effective_call_id + self._stream_paths[store_key] = _writer.stream_path return result except BaseException as exc: status = 'cancelled' if isinstance( @@ -749,6 +1190,25 @@ def _emit_stream_event(event: Dict[str, Any]) -> None: f'Failed to write agent tool stats for {spec.tool_name}: {exc}' ) + def _save_transcript(self, messages: Any, + agent_tag: Optional[str]) -> None: + if not isinstance(messages, list) or not agent_tag: + return + try: + output_dir = getattr(self.config, 'output_dir', './output') + subagents_dir = os.path.join(output_dir, 'subagents') + os.makedirs(subagents_dir, exist_ok=True) + path = os.path.join(subagents_dir, f'{agent_tag}.jsonl') + with open(path, 'w', encoding='utf-8') as f: + for msg in messages: + if hasattr(msg, 'to_dict'): + f.write( + json.dumps(msg.to_dict(), ensure_ascii=False) + + '\n') + except Exception as exc: + logger.warning( + f'Failed to save sub-agent transcript for {agent_tag}: {exc}') + def _build_payload(self, tool_args: dict, spec: _AgentToolSpec): if spec.input_mode == 'messages': field = spec.request_field or 'messages' diff --git a/ms_agent/tools/base.py b/ms_agent/tools/base.py index 12ece9948..18e66d175 100644 --- a/ms_agent/tools/base.py +++ b/ms_agent/tools/base.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import abstractmethod +from omegaconf import DictConfig from typing import Any, Dict from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR -from omegaconf import DictConfig class ToolBase: diff --git a/ms_agent/tools/code/code_executor.py b/ms_agent/tools/code/code_executor.py index 1df89e39c..cfe43c3b5 100644 --- a/ms_agent/tools/code/code_executor.py +++ b/ms_agent/tools/code/code_executor.py @@ -1,17 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json import socket +from omegaconf import DictConfig from pathlib import Path from typing import Any, Dict, Optional, Union -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.code.sandbox_manager import SandboxManagerFactory from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package -from omegaconf import DictConfig logger = get_logger() @@ -117,7 +117,8 @@ def _build_sandbox_config( self, config) -> Union['DockerNotebookConfig', 'DockerSandboxConfig']: """Build sandbox configuration from agent config""" - from ms_enclave.sandbox.model import DockerNotebookConfig, DockerSandboxConfig, SandboxType + from ms_enclave.sandbox.model import (DockerNotebookConfig, + DockerSandboxConfig, SandboxType) # Get sandbox-specific config or use defaults if isinstance(config, DictConfig) and hasattr( diff --git a/ms_agent/tools/code/local_code_executor.py b/ms_agent/tools/code/local_code_executor.py index 65de0556e..63cd4d3b9 100644 --- a/ms_agent/tools/code/local_code_executor.py +++ b/ms_agent/tools/code/local_code_executor.py @@ -2,19 +2,23 @@ import asyncio.subprocess as ai_subprocess import inspect import io +import json import os +import shlex import shutil import time from contextlib import redirect_stderr, redirect_stdout from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger +from ms_agent.utils.artifact_manager import ArtifactManager from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR from ms_agent.utils.utils import install_package +from ms_agent.utils.workspace_policy import (WorkspacePolicyError, + WorkspacePolicyKernel) logger = get_logger() @@ -231,7 +235,8 @@ class LocalCodeExecutionTool(ToolBase): def __init__(self, config): super().__init__(config) self.output_dir = Path( - getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR)).expanduser() + getattr(config, 'output_dir', + DEFAULT_OUTPUT_DIR)).expanduser().resolve() self.output_dir.mkdir(parents=True, exist_ok=True) self.tool_config = getattr( @@ -250,6 +255,39 @@ def __init__(self, config): self.shell_env = shell_env self._kernel_lock = asyncio.Lock() self._initialized = False + self._task_manager = None + self._watcher_tasks: Set[asyncio.Task] = set() + + wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) + extra_allow: List[str] = [] + deny_globs = None + if wp is not None: + extra_allow = list(getattr(wp, 'allow_roots', []) or []) + dg = getattr(wp, 'deny_globs', None) + if dg: + deny_globs = list(dg) + shell_cfg = getattr(self.tool_config, 'shell', + None) if self.tool_config else None + shell_mode = getattr( + shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', + False)) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', + 8192)) if shell_cfg else 8192 + self._policy = WorkspacePolicyKernel( + self.output_dir, + extra_allow_roots=extra_allow, + deny_globs=deny_globs, + shell_default_mode=str(shell_mode), + shell_network_enabled=net, + max_command_chars=max_cmd, + ) + max_kb = 256 + if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): + max_kb = int(shell_cfg.max_output_kb) + self._artifacts = ArtifactManager( + self.output_dir, max_combined_bytes=max_kb * 1024) self.exclude_func( getattr(getattr(config, 'tools', None), 'code_executor', None)) @@ -264,6 +302,10 @@ def __init__(self, config): logger.info('LocalCodeExecutionTool initialized (ipykernel based)') + def set_task_manager(self, task_manager) -> None: + """Attach process-wide TaskManager for background shell (see shell_executor).""" + self._task_manager = task_manager + def _check_dependencies(self) -> None: import importlib @@ -341,6 +383,10 @@ async def connect(self) -> None: self._initialized = True async def cleanup(self) -> None: + for t in list(self._watcher_tasks): + if not t.done(): + t.cancel() + self._watcher_tasks.clear() if not self._initialized: return await self.kernel_session.stop() @@ -421,9 +467,12 @@ async def _get_tools_inner(self) -> Dict[str, Any]: Tool( tool_name='shell_executor', server_name='code_executor', - description=('Execute shell commands locally using bash. ' - 'Supports basic shell operations like ls, ' - 'cd, mkdir, rm, etc. '), + description= + ('Execute shell commands locally under the workspace output directory (cwd). ' + 'Subject to policy (read_only vs workspace_write, network toggle). ' + 'Large stdout/stderr may be spilled to .ms_agent_artifacts. ' + 'Use run_in_background=true to return immediately with task_id; poll via task notifications.' + ), parameters={ 'type': 'object', 'properties': { @@ -435,7 +484,19 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'type': 'integer', 'description': 'Execution timeout in seconds', 'default': self._shell_timeout - } + }, + 'run_in_background': { + 'type': 'boolean', + 'description': + 'If true, start the command asynchronously and return task_id (requires TaskManager).', + 'default': False, + }, + '__call_id': { + 'type': + 'string', + 'description': + 'Optional correlation id (injected by host when supported).', + }, }, 'required': ['command'], 'additionalProperties': False @@ -630,16 +691,118 @@ def _exec_code(): async def shell_executor(self, command: str, - timeout: Optional[int] = None) -> str: + timeout: Optional[int] = None, + run_in_background: bool = False, + __call_id: Optional[str] = None) -> str: exec_timeout = timeout or self._shell_timeout + call_id = __call_id or f'shell-{os.urandom(4).hex()}' + + try: + self._policy.assert_shell_command_allowed(command) + except WorkspacePolicyError as e: + return json.dumps( + { + 'success': False, + 'error': str(e) + }, + ensure_ascii=False, + indent=2, + ) + + shell_cmd = self._prepare_shell_command(command) + + if run_in_background: + if self._task_manager is None: + return json.dumps( + { + 'success': + False, + 'error': + 'run_in_background requires TaskManager (host must wire LLMAgent.task_manager).', + }, + ensure_ascii=False, + indent=2, + ) + try: + process = await asyncio.create_subprocess_shell( + shell_cmd, + stdout=ai_subprocess.PIPE, + stderr=ai_subprocess.PIPE, + cwd=str(self._policy.workspace_root), + env=self.shell_env, + ) + except FileNotFoundError as exc: + return json.dumps( + { + 'success': False, + 'error': f'Shell not available: {exc}' + }, + ensure_ascii=False, + indent=2, + ) + + task_id = self._task_manager.register( + task_type='shell', + tool_name='shell_executor', + description=command[:200], + proc=process, + ) + + async def _watcher() -> None: + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=exec_timeout) + stdout_text = _coerce_str(stdout).strip('\n') + stderr_text = _coerce_str(stderr).strip('\n') + success = process.returncode == 0 + payload = { + 'success': success, + 'output': stdout_text, + 'error': stderr_text or None, + 'return_code': process.returncode, + } + text = self._artifacts.pack_json_shell_result( + tool_name='shell_executor', + call_id=task_id, + payload=payload, + ) + await self._task_manager.complete(task_id, text) + except asyncio.TimeoutError: + try: + process.kill() + await process.communicate() + except Exception: # noqa: B902 + pass + await self._task_manager.fail( + task_id, + f'Shell command timed out after {exec_timeout} seconds', + ) + except Exception as exc: # noqa: B902 + await self._task_manager.fail(task_id, str(exc)) + + t = asyncio.create_task(_watcher()) + self._watcher_tasks.add(t) + t.add_done_callback(self._watcher_tasks.discard) + + return json.dumps( + { + 'status': 'async_launched', + 'task_id': task_id, + 'tool_name': 'shell_executor', + 'call_id': call_id, + }, + ensure_ascii=False, + indent=2, + ) try: process = await asyncio.create_subprocess_shell( - command, + shell_cmd, stdout=ai_subprocess.PIPE, stderr=ai_subprocess.PIPE, - cwd=str(self.output_dir), - env=self.shell_env) + cwd=str(self._policy.workspace_root), + env=self.shell_env, + ) except FileNotFoundError as exc: return json.dumps( { @@ -647,7 +810,8 @@ async def shell_executor(self, 'error': f'Shell not available: {exc}' }, ensure_ascii=False, - indent=2) + indent=2, + ) try: stdout, stderr = await asyncio.wait_for( @@ -666,20 +830,23 @@ async def shell_executor(self, f'Shell command timed out after {exec_timeout} seconds' }, ensure_ascii=False, - indent=2) + indent=2, + ) stdout_text = _coerce_str(stdout).strip('\n') stderr_text = _coerce_str(stderr).strip('\n') success = process.returncode == 0 - return json.dumps( - { - 'success': success, - 'output': stdout_text, - 'error': stderr_text or None, - 'return_code': process.returncode - }, - ensure_ascii=False, - indent=2) + payload = { + 'success': success, + 'output': stdout_text, + 'error': stderr_text or None, + 'return_code': process.returncode, + } + return self._artifacts.pack_json_shell_result( + tool_name='shell_executor', + call_id=call_id, + payload=payload, + ) async def file_operation(self, operation: str, diff --git a/ms_agent/tools/code/sandbox_manager.py b/ms_agent/tools/code/sandbox_manager.py index 9744ca7f8..4d3d843ce 100644 --- a/ms_agent/tools/code/sandbox_manager.py +++ b/ms_agent/tools/code/sandbox_manager.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig from typing import Union from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -58,7 +58,8 @@ async def create_manager( Raises: ValueError: If sandbox mode is unknown """ - from ms_enclave.sandbox.manager import HttpSandboxManager, LocalSandboxManager + from ms_enclave.sandbox.manager import (HttpSandboxManager, + LocalSandboxManager) # Extract sandbox configuration if isinstance(config, DictConfig) and hasattr( diff --git a/ms_agent/tools/code_server/lsp_code_server.py b/ms_agent/tools/code_server/lsp_code_server.py index df1e043e5..ac7f3fbaf 100644 --- a/ms_agent/tools/code_server/lsp_code_server.py +++ b/ms_agent/tools/code_server/lsp_code_server.py @@ -1,11 +1,11 @@ import asyncio +import json import os import shutil import sys from pathlib import Path from typing import Any, Dict, List, Optional -import json from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger from ms_agent.utils.constants import (DEFAULT_INDEX_DIR, DEFAULT_LOCK_DIR, diff --git a/ms_agent/tools/docling/chunker.py b/ms_agent/tools/docling/chunker.py index 3781cb76b..5a7c894c6 100644 --- a/ms_agent/tools/docling/chunker.py +++ b/ms_agent/tools/docling/chunker.py @@ -1,5 +1,3 @@ -from typing import Iterable, Iterator, List, Union - from docling_core.transforms.chunker import BaseChunk, DocChunk from docling_core.transforms.chunker.hierarchical_chunker import ( ChunkingDocSerializer, ChunkingSerializerProvider) @@ -10,11 +8,12 @@ from docling_core.transforms.serializer.markdown import MarkdownParams from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel -from ms_agent.utils.logger import get_logger from rich.console import Console from rich.panel import Panel +from typing import Iterable, Iterator, List, Union from modelscope import AutoTokenizer +from ms_agent.utils.logger import get_logger logger = get_logger() diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py index 5daf1652b..daac553e5 100644 --- a/ms_agent/tools/docling/doc_loader.py +++ b/ms_agent/tools/docling/doc_loader.py @@ -2,9 +2,6 @@ # yapf: disable import ast import os -from typing import Dict, Iterator, List, Optional, Tuple, Union -from unittest.mock import patch as mock_patch - from docling.backend.html_backend import HTMLDocumentBackend from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.base_models import InputFormat @@ -18,6 +15,9 @@ from docling.models.table_structure_model import TableStructureModel from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem +from typing import Dict, Iterator, List, Optional, Tuple, Union +from unittest.mock import patch as mock_patch + from ms_agent.tools.docling.doc_postprocess import PostProcess from ms_agent.tools.docling.patches import (download_models_ms, download_models_pic_classifier_ms, diff --git a/ms_agent/tools/docling/doc_postprocess.py b/ms_agent/tools/docling/doc_postprocess.py index 2e21051d9..2b66dd95e 100644 --- a/ms_agent/tools/docling/doc_postprocess.py +++ b/ms_agent/tools/docling/doc_postprocess.py @@ -1,6 +1,5 @@ -from typing import Union - from docling_core.types import DoclingDocument +from typing import Union class PostProcess: diff --git a/ms_agent/tools/docling/patches.py b/ms_agent/tools/docling/patches.py index b2ec77afd..4191e706b 100644 --- a/ms_agent/tools/docling/patches.py +++ b/ms_agent/tools/docling/patches.py @@ -1,10 +1,10 @@ # flake8: noqa import sys -from pathlib import Path - from bs4 import Tag from docling_core.types import DoclingDocument from docling_core.types.doc import DocItemLabel, ImageRef +from pathlib import Path + from ms_agent.utils.logger import get_logger from ms_agent.utils.utils import (load_image_from_uri_to_pil, load_image_from_url_to_pil, validate_url) diff --git a/ms_agent/tools/fetch_playwright_fallback.py b/ms_agent/tools/fetch_playwright_fallback.py new file mode 100644 index 000000000..b173217df --- /dev/null +++ b/ms_agent/tools/fetch_playwright_fallback.py @@ -0,0 +1,190 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Optional headless Chromium fetch for URLs where Jina + direct HTTP yield empty +or obviously client-rendered shells. + +Requires: ``pip install playwright`` and ``playwright install chromium``. +If Playwright is not installed, helpers return empty string without raising. + +Performance: Playwright ``Browser`` must be used from the creating thread only. +We keep **one browser per thread** (e.g. each ``ThreadPoolExecutor`` worker) and +reuse it across URLs instead of launching Chromium for every fetch. +""" +from __future__ import annotations + +import atexit +import os +import re +import threading +from typing import Dict, List, Tuple + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_MAX_INNER_TEXT_CHARS = 100_000 + +_tls = threading.local() +_registry_lock = threading.Lock() +# thread_id -> (sync_playwright handle, Browser) for atexit cleanup +_pw_by_thread: Dict[int, Tuple[object, object]] = {} + + +def _chromium_launch_args() -> List[str]: + args: List[str] = [ + '--disable-extensions', + '--blink-settings=imagesEnabled=false', + ] + if os.getenv('MS_AGENT_PLAYWRIGHT_NO_SANDBOX', '').lower() in ( + '1', + 'true', + 'yes', + ): + args.extend(('--no-sandbox', '--disable-setuid-sandbox')) + return args + + +def _invalidate_thread_playwright_unlocked() -> None: + """Drop thread-local Playwright; caller must hold no registry lock if mutating _pw_by_thread.""" + tid = threading.get_ident() + pw = getattr(_tls, 'pw', None) + br = getattr(_tls, 'browser', None) + _tls.pw = None + _tls.browser = None + with _registry_lock: + _pw_by_thread.pop(tid, None) + if br is not None: + try: + br.close() + except Exception: + pass + if pw is not None: + try: + pw.stop() + except Exception: + pass + + +def _atexit_close_all_playwright() -> None: + with _registry_lock: + items = list(_pw_by_thread.values()) + _pw_by_thread.clear() + for pw, browser in items: + try: + browser.close() + except Exception: + pass + try: + pw.stop() + except Exception: + pass + + +atexit.register(_atexit_close_all_playwright) + + +def _thread_browser() -> object: + """Return a Chromium ``Browser`` for this thread, creating it lazily.""" + br = getattr(_tls, 'browser', None) + if br is not None: + try: + if br.is_connected(): + return br + except Exception: + pass + _invalidate_thread_playwright_unlocked() + br = None + + try: + from playwright.sync_api import sync_playwright + except ImportError: + logger.debug( + 'playwright is not installed; skip headless fetch. ' + 'Install with: pip install playwright && playwright install chromium' + ) + raise RuntimeError('playwright not installed') from None + + pw = sync_playwright().start() + browser = pw.chromium.launch( + headless=True, + args=_chromium_launch_args(), + ) + _tls.pw = pw + _tls.browser = browser + with _registry_lock: + _pw_by_thread[threading.get_ident()] = (pw, browser) + return browser + + +def try_playwright_inner_text( + url: str, + timeout_ms: int, + *, + settle_ms: int = 350, +) -> str: + """ + Load URL in headless Chromium and return ``document.body.innerText``. + + Reuses one browser per thread. Returns empty string on missing dependency, + timeout, or navigation error. + """ + if not url.startswith(('http://', 'https://')): + return '' + settle_ms = max(0, int(settle_ms)) + try: + from playwright.sync_api import sync_playwright # noqa: F401 + except ImportError: + logger.debug( + 'playwright is not installed; skip headless fetch. ' + 'Install with: pip install playwright && playwright install chromium' + ) + return '' + + text = '' + try: + browser = _thread_browser() + page = browser.new_page() + try: + page.set_default_timeout(timeout_ms) + page.goto(url, wait_until='domcontentloaded', timeout=timeout_ms) + if settle_ms: + page.wait_for_timeout(settle_ms) + raw = page.evaluate("""() => { + const b = document.body; + if (!b) return ''; + return b.innerText || ''; + }""") + if isinstance(raw, str): + text = raw[:_MAX_INNER_TEXT_CHARS] + finally: + try: + page.close() + except Exception: + pass + except RuntimeError: + return '' + except Exception as e: + logger.debug(f'Playwright fetch failed for {url[:80]!r}: {e}') + try: + _invalidate_thread_playwright_unlocked() + except Exception: + pass + return '' + + return text + + +def looks_like_spa_shell_html(raw_html: str) -> bool: + """Heuristic: HTML suggests JS-only app or empty mount root.""" + if not raw_html or len(raw_html) < 80: + return False + low = raw_html.lower() + if any( + x in low for x in ('enable javascript', 'javascript is required', + 'you need to enable javascript')): + return True + if re.search(r']+\bid=["\']root["\'][^>]*>\s*', low): + return True + if re.search(r']+\bid=["\']app["\'][^>]*>\s*', low): + return True + return False diff --git a/ms_agent/tools/filesystem_tool.py b/ms_agent/tools/filesystem_tool.py index a107a7c59..0126f4f02 100644 --- a/ms_agent/tools/filesystem_tool.py +++ b/ms_agent/tools/filesystem_tool.py @@ -1,32 +1,88 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio +import base64 import fnmatch +import json import os import re import shutil from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional -import json from ms_agent.llm import LLM from ms_agent.llm.utils import Message, Tool from ms_agent.tools.base import ToolBase -from ms_agent.utils import MAX_CONTINUE_RUNS, get_logger, retry +from ms_agent.utils import get_logger +from ms_agent.utils.artifact_manager import ArtifactManager from ms_agent.utils.constants import DEFAULT_INDEX_DIR, DEFAULT_OUTPUT_DIR -from openai import OpenAI +from ms_agent.utils.workspace_policy import (WorkspacePolicyError, + WorkspacePolicyKernel) logger = get_logger() +_FS_TOOL_ALIASES = { + 'read': 'read_file', + 'edit': 'edit_file', + 'write': 'write_file' +} + +_TEXT_SUFFIXES = { + '.py', + '.md', + '.txt', + '.yaml', + '.yml', + '.json', + '.toml', + '.cfg', + '.ini', + '.sh', + '.bash', + '.js', + '.ts', + '.tsx', + '.jsx', + '.css', + '.html', + '.xml', + '.rs', + '.go', + '.java', + '.c', + '.h', + '.cpp', + '.hpp', + '.cs', + '.rb', + '.php', + '.sql', + '.vue', + '.svelte', + '.m', + '.swift', + '.kt', + '.gradle', + '.properties', + '.env', + '.gitignore', + '.dockerignore', + 'Dockerfile', +} + class FileSystemTool(ToolBase): """A file system operation tool""" - # Directories to exclude from file operations - EXCLUDED_DIRS = { - 'node_modules', 'dist', '.git', '__pycache__', '.venv', 'venv' + MAX_READ_BYTES = 10 * 1024 * 1024 # 10 MB per file + IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif', 'webp'}) + # Curly quote → straight quote mapping for fuzzy matching + CURLY_QUOTE_MAP = { + '\u2018': "'", + '\u2019': "'", # ' ' + '\u201c': '"', + '\u201d': '"', # " " } - # File prefixes to exclude - EXCLUDED_FILE_PREFIXES = ('.', '..', '__pycache__') SYSTEM_FOR_ABBREVIATIONS = """你是一个帮我简化文件信息并返回缩略的机器人,你需要根据输入文件内容来生成压缩过的文件内容。 @@ -45,14 +101,15 @@ class FileSystemTool(ToolBase): def __init__(self, config, **kwargs): super().__init__(config) self.exclude_func(getattr(config.tools, 'file_system', None)) + if self.include_functions: + self.include_functions = [ + _FS_TOOL_ALIASES.get(n, n) for n in self.include_functions + ] + if self.exclude_functions: + self.exclude_functions = [ + _FS_TOOL_ALIASES.get(n, n) for n in self.exclude_functions + ] self.output_dir = getattr(config, 'output_dir', DEFAULT_OUTPUT_DIR) - if self.exclude_functions and 'edit_file' not in self.exclude_functions \ - or self.include_functions and 'edit_file' in self.include_functions: - self.edit_file_config = getattr(config.tools.file_system, - 'edit_file_config', None) - self.edit_client = OpenAI( - api_key=self.edit_file_config.api_key, - base_url=self.edit_file_config.base_url) self.trust_remote_code = kwargs.get('trust_remote_code', False) self.allow_read_all_files = getattr( getattr(config.tools, 'file_system', {}), 'allow_read_all_files', @@ -66,6 +123,48 @@ def __init__(self, config, **kwargs): self.system = self.SYSTEM_FOR_ABBREVIATIONS if hasattr(self.config.tools.file_system, 'system_for_abbreviations'): self.system = self.config.tools.file_system.system_for_abbreviations + # read_file dedup only: {real_path: {"mtime", "offset", "limit"}} + self._read_cache: dict[str, dict] = {} + + fs_cfg = getattr(config.tools, 'file_system', None) + self._grep_timeout = int(getattr(fs_cfg, 'grep_timeout_s', 120) or 120) + self._default_grep_head = int( + getattr(fs_cfg, 'grep_head_limit', 250) or 250) + self._glob_max_files = int( + getattr(fs_cfg, 'glob_max_files', 100) or 100) + + wp = getattr(getattr(config, 'tools', None), 'workspace_policy', None) + extra = list(getattr(wp, 'allow_roots', []) or []) if wp else [] + deny = list(getattr(wp, 'deny_globs', []) or []) if wp else [] + + shell_cfg = getattr( + getattr(config.tools, 'code_executor', None), 'shell', None) + shell_mode = getattr( + shell_cfg, 'default_mode', + 'workspace_write') if shell_cfg else 'workspace_write' + net = bool(getattr(shell_cfg, 'network_enabled', + False)) if shell_cfg else False + max_cmd = int(getattr(shell_cfg, 'max_command_chars', + 8192)) if shell_cfg else 8192 + + _out_p = Path(self.output_dir).expanduser().resolve() + try: + _out_p.mkdir(parents=True, exist_ok=True) + except OSError: + pass + self._fs_policy = WorkspacePolicyKernel( + _out_p, + extra_allow_roots=extra, + deny_globs=deny if deny else None, + shell_default_mode=str(shell_mode), + shell_network_enabled=net, + max_command_chars=max_cmd, + ) + max_kb = 256 + if shell_cfg and getattr(shell_cfg, 'max_output_kb', None): + max_kb = int(shell_cfg.max_output_kb) + self._fs_artifacts = ArtifactManager( + _out_p, max_combined_bytes=max_kb * 1024) async def connect(self): logger.warning_once( @@ -76,9 +175,17 @@ async def _get_tools_inner(self): tools = { 'file_system': [ Tool( - tool_name='create_directory', + tool_name='write_file', server_name='file_system', - description='Create a directory', + description= + ('Write content to a file. Creates the file if it does not exist, ' + 'or overwrites it if it does.\n\n' + 'Usage:\n' + '- Prefer `edit_file` for modifying existing files — it only changes the relevant section.\n' + '- Use this tool to create new files or perform a complete rewrite.\n' + '- Parent directories are created automatically if they do not exist.\n\n' + 'No prior `read_file` is required; the path is resolved and the given `content` is written as-is.' + ), parameters={ 'type': 'object', 'properties': { @@ -86,37 +193,34 @@ async def _get_tools_inner(self): 'type': 'string', 'description': - 'The relative path of the directory to create', - } - }, - 'required': ['path'], - 'additionalProperties': False - }), - Tool( - tool_name='write_file', - server_name='file_system', - description='Write content into a file', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': 'string', - 'description': 'The relative path of the file', + 'The relative path of the file to write', }, 'content': { - 'type': 'string', - 'description': 'The content of the file', + 'type': + 'string', + 'description': + 'The full content to write into the file', }, }, 'required': ['path', 'content'], 'additionalProperties': False }), Tool( - tool_name='read_abbreviation_file', + tool_name='read_file', server_name='file_system', description= - 'Read the abbreviation content of file(s). If the information is not enough, ' - 'read the original file by `read_file`', + ('Read the content of one or more files.\n\n' + '- `paths`: list of relative file paths to read (preferred).\n' + '- `path`: single relative file path (alias when the model passes one file).\n' + '- For image files (png/jpg/jpeg/gif/webp), returns base64-encoded content.\n' + '- `offset`: line number to start reading from (1-based). ' + 'Only effective when paths has exactly one element. Omit to read from the beginning.\n' + '- `limit`: number of lines to read. ' + 'Only effective when paths has exactly one element. Omit to read to the end.\n' + '- `abbreviate`: if true, use an LLM to return a condensed summary of each file ' + 'instead of the raw content. Cached after first call. ' + 'Use this for a quick structural overview; read the full file if more detail is needed.' + ), parameters={ 'type': 'object', 'properties': { @@ -127,300 +231,170 @@ async def _get_tools_inner(self): 'type': 'string' }, 'description': - 'List of relative file path(s) to read, format: {"paths": ["file1", "file2"]}"]}', + 'List of relative file path(s) to read. ' + 'Use this OR `path` (single file).', }, - }, - 'required': ['paths'], - 'additionalProperties': False - }), - Tool( - tool_name='read_file', - server_name='file_system', - description= - 'Read the content of file(s). When reading a single file, optionally specify line range.', - parameters={ - 'type': 'object', - 'properties': { - 'paths': { + 'path': { 'type': - 'array', - 'items': { - 'type': 'string' - }, + 'string', 'description': - 'List of relative file path(s) to read, format: {"paths": ["file1", "file2"]}"]}', + 'Single relative file path to read (alias for `paths` of length 1).', }, - 'start_line': { + 'offset': { 'type': 'integer', 'description': - 'Start line number (1-based, inclusive). Only effective when paths has exactly one ' - 'element. 0 or omit to read from beginning.', + 'Line number to start reading from (1-based). ' + 'Only provide if the file is too large to read at once.', }, - 'end_line': { + 'limit': { 'type': 'integer', 'description': - 'End line number (1-based, inclusive). Only effective when paths has exactly one ' - 'element. Omit to read to the end.', + 'Number of lines to read. ' + 'Only provide if the file is too large to read at once.', }, - }, - 'required': ['paths'], - 'additionalProperties': False - }), - Tool( - tool_name='list_files', - server_name='file_system', - description='List all files in a directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { + 'abbreviate': { 'type': - 'string', + 'boolean', 'description': - "The path to list files, if path is None or '' or not given, " - 'the root dir will be used as path.', - } + 'If true, return an LLM-generated summary instead of raw content. ' + 'Useful for large files or quick structural overview.', + }, }, 'required': [], 'additionalProperties': False }), - Tool( - tool_name='delete_file_or_dir', - server_name='file_system', - description='Delete one file or one directory', - parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': 'string', - 'description': 'The relative path to delete', - } - }, - 'required': ['path'], - 'additionalProperties': False - }), Tool( tool_name='edit_file', server_name='file_system', description= - ('Use this tool to make an edit to an existing file.\n\n' - 'This will be read by a less intelligent model, which will quickly apply the edit. ' - 'You should make it clear what the edit is, while also minimizing the unchanged code you write.\n' - 'When writing the edit, you should specify each edit in sequence, with the special comment ' - '// ... existing code ... to represent unchanged code in between edited lines.\n\n' - 'For example:\n\n// ... existing code ...\nFIRST_EDIT\n// ... existing code ...\n' - 'SECOND_EDIT\n// ... existing code ...\nTHIRD_EDIT\n// ... existing code ...\n\n' - 'You should still bias towards repeating as few lines of the original file ' - 'as possible to convey the change.\n' - 'But, each edit should contain minimally sufficient context of unchanged lines ' - "around the code you're editing to resolve ambiguity.\n" - 'DO NOT omit spans of pre-existing code (or comments) without using the ' - '// ... existing code ... comment to indicate its absence. ' - 'If you omit the existing code comment, the model may inadvertently delete these lines.\n' - 'If you plan on deleting a section, you must provide context before and after to delete it. ' - 'If the initial code is ```code \\n Block 1 \\n Block 2 \\n Block 3 \\n code```, ' - 'and you want to remove Block 2, you would output ' - '```// ... existing code ... \\n Block 1 \\n Block 3 \\n // ... existing code ...```.\n' - 'Make sure it is clear what the edit should be, and where it should be applied.\n' - 'Make edits to a file in a single edit_file call ' - 'instead of multiple edit_file calls to the same file. ' - 'The apply model can handle many distinct edits at once.' + ('Edit an existing file by replacing an exact string with new content.\n\n' + 'The tool reads the file from disk and checks that `old_string` appears in the current ' + 'contents before applying the edit — you do not need a prior `read_file` call for ' + 'staleness. Still use `read_file` or `grep` when you need the exact snippet in the ' + 'conversation so you can form a correct `old_string`.\n\n' + 'You must provide the exact text to find (`old_string`) and the replacement (`new_string`).\n' + '`old_string` must match the file content EXACTLY — including whitespace and line breaks.\n' + 'If `old_string` appears multiple times and `replace_all` is false, the call will fail ' + 'with the match count so you can add more context to make it unique.\n\n' + 'Special case — `old_string=""`:\n' + '- File does not exist: creates the file with `new_string` as its content.\n' + '- File exists and is empty: fills it with `new_string`.\n' + '- File exists and has content: returns an error. Use `write_file` for a full rewrite.' ), parameters={ 'type': 'object', 'properties': { 'path': { - 'type': 'string', - 'description': - 'Path of the target file to modify.' - }, - 'instructions': { 'type': 'string', 'description': - ('A single sentence instruction describing ' - 'what you are going to do for the sketched edit. ' - 'This is used to assist the less intelligent model in applying the edit. ' - 'Use the first person to describe what you are going to do. ' - 'Use it to disambiguate uncertainty in the edit.' - ) + 'The relative path of the file to edit.', }, - 'code_edit': { - 'type': - 'string', - 'description': - ('Specify ONLY the precise lines of code that you wish to edit. ' - 'NEVER specify or write out unchanged code. ' - 'Instead, represent all unchanged code using the comment of the language ' - "you're editing in - example: // ... existing code ..." - ) - } - }, - 'required': ['path', 'instructions', 'code_edit'] - }), - Tool( - tool_name='search_file_content', - server_name='file_system', - description= - 'Search for content in files using literal text or regex patterns. ' - 'Automatically detects and supports both literal string matching and regex pattern matching. ' - 'Returns matching files with line numbers and surrounding context.', - parameters={ - 'type': 'object', - 'properties': { - 'content': { + 'old_string': { 'type': 'string', 'description': - 'The content/text or regex pattern to search for. ' - 'Supports both literal strings and regex patterns automatically.', + 'The exact string to find and replace.', }, - 'parent_path': { - 'type': - 'string', - 'description': - 'The relative parent path to search in (optional, defaults to root)', - }, - 'file_pattern': { - 'type': - 'string', + 'new_string': { + 'type': 'string', 'description': - 'Wildcard pattern for file names, e.g., "*.py", "*.js", "test_*.py" ' - '(default: "*" for all files)', + 'The string to replace it with.', }, - 'context_lines': { + 'replace_all': { 'type': - 'integer', + 'boolean', 'description': - 'Number of lines before and after the match to include (default: 2)', + 'If true, replace all occurrences. Default is false (replace only the first).', }, }, - 'required': ['content'], + 'required': ['path', 'old_string', 'new_string'], 'additionalProperties': False }), Tool( - tool_name='search_file_name', + tool_name='grep', server_name='file_system', description= - 'Search for files by name using regex pattern matching. ' - 'Supports both regex patterns and simple substring matching. ' - 'If the file parameter is a valid regex pattern, it will be used for regex matching; ' - 'otherwise, falls back to substring matching. ' - 'The parent_path can also be a regex pattern to filter directories.', + ('Search file contents under the workspace using ripgrep when available, ' + 'otherwise a safe Python scan. Paths must stay under the configured output/workspace roots. ' + 'Read-only.'), parameters={ 'type': 'object', 'properties': { - 'file': { + 'pattern': { 'type': 'string', 'description': - 'The filename pattern to search for (supports regex, e.g., r"\\.js$" for .js files, ' - 'or "service" for substring match).', + 'Regular expression (Rust regex if rg is used).', }, - 'parent_path': { + 'path': { 'type': 'string', 'description': - 'The relative parent path to search in (supports regex for directory filtering, ' - 'e.g., r"backend.*" to match backend-related directories). ' - 'Defaults to root if not specified.', + 'Directory or file to search (relative to output_dir if not absolute). Default ".".', }, - }, - 'required': ['file'], - 'additionalProperties': False - }), - Tool( - tool_name='replace_file_lines', - server_name='file_system', - description= - 'Replace specific line ranges in a file. Supports inserting at beginning ' - '(start_line=0) or end (start_line=-1). Line numbers are 1-based and inclusive on both ends.\n\n' - 'IMPORTANT — Line-number shift after each call. Every replacement changes the total line count, ' - 'which invalidates ALL line numbers after the replaced range. If you need to make multiple replacements in the same file:\n' - '- Option A (recommended): Work from BOTTOM to TOP — edit the largest line numbers first so earlier line numbers remain valid.\n' - '- Option B: Re-search after each replacement to get updated line numbers before the next replacement.\n' - '- Option C: Pre-calculate the cumulative offset — each replacement shifts subsequent lines by (new_content_lines - replaced_lines).\n' - 'NEVER call this tool multiple times in parallel on the same file — the concurrent line-number ' - 'shifts will corrupt the file. Always call sequentially.\n', - parameters={ - 'type': 'object', - 'properties': { - 'path': { + 'glob': { 'type': 'string', 'description': - 'The relative path of the file to modify', + 'Optional glob filter for files, e.g. "*.py"', }, - 'content': { - 'type': 'string', + 'output_mode': { + 'type': + 'string', + 'enum': + ['content', 'files_with_matches', 'count'], 'description': - 'The new content to insert/replace', + 'content: matching lines; files_with_matches: paths only; count: per-file counts', }, - 'start_line': { + 'head_limit': { 'type': 'integer', 'description': - 'Start line number (1-based, inclusive). Use 0 to insert at beginning, ' - '-1 to append at end', + 'Max lines (content) or paths/count entries to return', }, - 'end_line': { + 'offset': { 'type': 'integer', 'description': - 'End line number (1-based, inclusive). Required unless start_line is 0 or -1', + 'Skip first N lines/entries after collect', + }, + 'case_insensitive': { + 'type': 'boolean', + 'description': 'Case-insensitive search', }, }, - 'required': ['path', 'content', 'start_line'], - 'additionalProperties': False - }), + 'required': ['pattern'], + 'additionalProperties': False, + }, + ), Tool( - tool_name='replace_file_contents', + tool_name='glob', server_name='file_system', description= - 'Replace exact content in a file without using line numbers. ' - 'You must provide:\n' - '[Required]path: The relative path of modified file.\n' - '[Required]source: The old content to be replaced.\n' - '[Required]target: The new content to replace the `source`.\n' - '[Required]occurrence: Which occurrence to replace (1-based).\n' - 'Do not miss any of these arguments!\n\n' - 'IMPORTANT:\n' - '- `source` must match the file content EXACTLY — including punctuation style ' - '(e.g., Chinese "、" vs English ","), whitespace, line breaks, and Unicode characters.', + ('List files under a workspace directory matching a glob pattern ' + '(e.g. "**/*.py", "*.md"). Read-only; results are capped.' + ), parameters={ 'type': 'object', 'properties': { - 'path': { - 'type': - 'string', - 'description': - 'The relative path of the file to modify', - }, - 'source': { - 'type': - 'string', - 'description': - 'The exact content to find and replace. Must match the file content ' - 'EXACTLY including all whitespace, punctuation, and line breaks. ', - }, - 'target': { + 'pattern': { 'type': 'string', - 'description': - 'The new content to replace with', + 'description': 'Glob pattern relative to path', }, - 'occurrence': { + 'path': { 'type': - 'integer', + 'string', 'description': - 'Which occurrence to replace (1-based). Default is 1 (first occurrence). ' - 'Use -1 to replace all occurrences.', + 'Base directory (relative to output_dir if not absolute).', }, }, - 'required': ['path', 'source', 'target', 'occurrence'], - 'additionalProperties': False - }), + 'required': ['pattern'], + 'additionalProperties': False, + }, + ), ] } return tools @@ -429,25 +403,339 @@ async def call_tool(self, server_name: str, *, tool_name: str, tool_args: dict) -> str: return await getattr(self, tool_name)(**tool_args) - async def create_directory(self, path: Optional[str] = None) -> str: - """Create a directory - - Args: - path(`str`): The relative directory path to create, a prefix dir will be automatically concatenated. + async def grep( + self, + pattern: str, + path: str = '.', + glob: Optional[str] = None, + output_mode: str = 'files_with_matches', + head_limit: Optional[int] = None, + offset: Optional[int] = None, + case_insensitive: bool = False, + ) -> str: + call_id = f'grep-{pattern[:40]}' + head_limit = ( + head_limit if head_limit is not None else self._default_grep_head) + offset = offset or 0 + path = path or '.' + try: + root = self._fs_policy.resolve_under_roots(path) + except WorkspacePolicyError as e: + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + if pattern is None or (isinstance(pattern, str) + and not pattern.strip()): + return json.dumps( + { + 'success': False, + 'error': 'grep requires a non-empty pattern string.', + }, + indent=2, + ) + if isinstance(pattern, str) and ('\n' in pattern or '\r' in pattern): + return json.dumps( + { + 'success': + False, + 'error': + ('grep pattern must not contain raw newline characters; ' + 'ripgrep rejects them unless multiline mode is enabled server-side. ' + 'Use several single-line patterns, escape newlines as needed, ' + 'or search with read_file for fixed multi-line text.'), + }, + indent=2, + ) - Returns: - or error message. - """ + lines: List[str] = [] try: - if not path: - path = self.output_dir + rg = shutil.which('rg') + if rg and root.is_file(): + lines = await self._grep_rg_file(rg, pattern, root, + case_insensitive, output_mode, + head_limit, offset, glob) + elif rg and root.is_dir(): + lines = await self._grep_rg_dir(rg, pattern, root, + case_insensitive, output_mode, + head_limit, offset, glob) else: - path = os.path.join(self.output_dir, path) - os.makedirs(path, exist_ok=True) - return f'Directory: <{path or "root path"}> was created.' + lines = self._grep_python( + pattern, + root, + glob, + output_mode, + head_limit, + offset, + case_insensitive, + ) except Exception as e: - return f'Create directory <{path or "root path"}> failed, error: ' + str( - e) + err = str(e) + # Expected user/tooling failures (bad regex, rg rules) — log without traceback noise. + _quiet = ('rg:' in err or 'exited' in err.lower() + or 'regex' in err.lower() or 'pattern' in err.lower()) + logger.warning('grep failed: %s', e, exc_info=not _quiet) + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + text = '\n'.join(lines) + packed = self._fs_artifacts.pack_text_result( + tool_name='grep', + call_id=call_id, + stdout=text, + stderr='', + extra={ + 'success': True, + 'output_mode': output_mode, + 'num_lines': len(lines), + }, + ) + return json.dumps(packed, ensure_ascii=False, indent=2, default=str) + + async def _grep_rg_file( + self, + rg: str, + pattern: str, + file_path: Path, + case_insensitive: bool, + output_mode: str, + head_limit: int, + offset: int, + glob_pat: Optional[str], + ) -> List[str]: + args = [rg, '--no-heading', '--color', 'never'] + if case_insensitive: + args.append('-i') + if glob_pat: + args.extend(['--glob', glob_pat]) + if output_mode == 'files_with_matches': + args.extend(['-l', pattern, str(file_path)]) + elif output_mode == 'count': + args.extend(['-c', pattern, str(file_path)]) + else: + args.extend(['-n', pattern, str(file_path)]) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._fs_policy.workspace_root), + ) + out_b, err_b = await asyncio.wait_for( + proc.communicate(), timeout=self._grep_timeout) + out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') + err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') + if proc.returncode not in (0, 1): + raise RuntimeError(err or f'rg exited {proc.returncode}') + lines = [ln for ln in out.split('\n') if ln] if out else [] + return _apply_offset_limit(lines, offset, head_limit) + + async def _grep_rg_dir( + self, + rg: str, + pattern: str, + root: Path, + case_insensitive: bool, + output_mode: str, + head_limit: int, + offset: int, + glob_pat: Optional[str], + ) -> List[str]: + args = [rg, '--no-heading', '--color', 'never'] + if case_insensitive: + args.append('-i') + if glob_pat: + args.extend(['--glob', glob_pat]) + if output_mode == 'files_with_matches': + args.extend(['-l', pattern, str(root)]) + elif output_mode == 'count': + args.extend(['--count-matches', pattern, str(root)]) + else: + args.extend(['-n', pattern, str(root)]) + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._fs_policy.workspace_root), + ) + out_b, err_b = await asyncio.wait_for( + proc.communicate(), timeout=self._grep_timeout) + out = (out_b or b'').decode('utf-8', errors='replace').strip('\n') + err = (err_b or b'').decode('utf-8', errors='replace').strip('\n') + if proc.returncode not in (0, 1): + raise RuntimeError(err or f'rg exited {proc.returncode}') + lines = [ln for ln in out.split('\n') if ln] if out else [] + return _apply_offset_limit(lines, offset, head_limit) + + def _grep_python( + self, + pattern: str, + root: Path, + glob_pat: Optional[str], + output_mode: str, + head_limit: int, + offset: int, + case_insensitive: bool, + ) -> List[str]: + flags = re.IGNORECASE if case_insensitive else 0 + try: + rx = re.compile(pattern, flags) + except re.error as e: + return [f'[error] invalid regex: {e}'] + lines_out: List[str] = [] + counts: Dict[str, int] = {} + + def consider_file(fp: Path) -> bool: + if glob_pat: + rel = str(fp.relative_to(root)) if root.is_dir() else fp.name + if not fnmatch.fnmatch(fp.name, + glob_pat) and not fnmatch.fnmatch( + rel, glob_pat): + return False + suf = fp.suffix.lower() + if suf not in _TEXT_SUFFIXES and fp.suffix == '': + if fp.name not in ('Dockerfile', 'Makefile', 'README'): + return False + return fp.is_file() + + files: List[Path] = [] + if root.is_file(): + files = [root] + else: + for fp in _walk_files_limited(root, self._fs_policy.deny_globs, + 50_000): + if consider_file(fp): + files.append(fp) + + for fp in files: + try: + text = fp.read_text(encoding='utf-8', errors='replace') + except OSError: + continue + rel = str(fp.relative_to( + self._fs_policy.workspace_root)) if _is_relative( + fp, self._fs_policy.workspace_root) else str(fp) + if output_mode == 'files_with_matches': + if rx.search(text): + lines_out.append(rel) + elif output_mode == 'count': + n = len(rx.findall(text)) + if n: + counts[rel] = n + else: + for i, line in enumerate(text.splitlines(), start=1): + if rx.search(line): + lines_out.append(f'{rel}:{i}:{line}') + if len(lines_out) >= head_limit + offset + 5000: + break + + if output_mode == 'count': + lines_out = [f'{k}:{v}' for k, v in sorted(counts.items())] + return _apply_offset_limit(lines_out, offset, head_limit) + + async def glob(self, pattern: str, path: str = '') -> str: + call_id = f'glob-{pattern[:40]}' + try: + base = self._fs_policy.resolve_under_roots(path or '.') + except WorkspacePolicyError as e: + return json.dumps({'success': False, 'error': str(e)}, indent=2) + + if not base.is_dir(): + return json.dumps( + { + 'success': False, + 'error': f'Not a directory: {path}', + }, + indent=2, + ) + + matches: List[str] = [] + truncated = False + deny = self._fs_policy.deny_globs + + try: + for p in sorted(base.glob(pattern)): + if not p.is_file(): + continue + rp = p.resolve() + if not self._fs_policy.path_is_allowed(rp): + continue + if _is_denied_path(rp, base, deny): + continue + rel = str(p.relative_to( + self._fs_policy.workspace_root)) if _is_relative( + p, self._fs_policy.workspace_root) else str(p) + matches.append(rel) + if len(matches) >= self._glob_max_files: + truncated = True + break + except ValueError: + return json.dumps( + { + 'success': False, + 'error': 'Invalid glob pattern', + }, + indent=2, + ) + + text = json.dumps( + { + 'success': True, + 'num_files': len(matches), + 'filenames': matches, + 'truncated': truncated, + }, + ensure_ascii=False, + indent=2, + ) + packed = self._fs_artifacts.pack_text_result( + tool_name='glob', + call_id=call_id, + stdout=text, + stderr='', + extra={'success': True}, + ) + return json.dumps(packed, ensure_ascii=False, indent=2, default=str) + + def _normalize_quotes(self, s: str) -> str: + for curly, straight in self.CURLY_QUOTE_MAP.items(): + s = s.replace(curly, straight) + return s + + def _preserve_quote_style(self, old_string: str, actual_old: str, + new_string: str) -> str: + """If old_string matched via quote normalization, apply the same curly quotes to new_string.""" + if old_string == actual_old: + return new_string + has_double = any(c in actual_old for c in '\u201c\u201d') + has_single = any(c in actual_old for c in '\u2018\u2019') + result = new_string + if has_double: + out, chars = [], list(result) + for i, ch in enumerate(chars): + if ch == '"': + prev = chars[i - 1] if i > 0 else None + opening = prev is None or prev in ' \t\n\r([{' + out.append('\u201c' if opening else '\u201d') + else: + out.append(ch) + result = ''.join(out) + if has_single: + out, chars = [], list(result) + for i, ch in enumerate(chars): + if ch == "'": + prev = chars[i - 1] if i > 0 else None + nxt = chars[i + 1] if i < len(chars) - 1 else None + # apostrophe in contraction → right single quote + if prev and nxt and prev.isalpha() and nxt.isalpha(): + out.append('\u2019') + else: + opening = prev is None or prev in ' \t\n\r([{' + out.append('\u2018' if opening else '\u2019') + else: + out.append(ch) + result = ''.join(out) + return result + + @staticmethod + def _strip_trailing_whitespace(s: str) -> str: + return '\n'.join(line.rstrip() for line in s.split('\n')) async def write_file(self, path: str, content: str): """Write content to a file. @@ -463,175 +751,18 @@ async def write_file(self, path: str, content: str): if not os.path.exists(self.output_dir): os.makedirs(self.output_dir, exist_ok=True) original_path = path # Preserve original path for error messages - path = self.get_real_path(path) - if path is None: + real_path = self.get_real_path(path) + if real_path is None: return f'<{original_path}> is out of the valid project path: {self.output_dir}' - dirname = os.path.dirname(path) + dirname = os.path.dirname(real_path) if dirname: - os.makedirs( - os.path.join(self.output_dir, dirname), exist_ok=True) - with open(os.path.join(self.output_dir, path), 'w') as f: + os.makedirs(dirname, exist_ok=True) + with open(real_path, 'w', encoding='utf-8') as f: f.write(content) return f'Save file <{path}> successfully.' except Exception as e: return f'Write file <{path}> failed, error: ' + str(e) - async def replace_file_contents(self, - path: str, - source: str = None, - target: str = None, - occurrence: int = 1): - """Replace exact content in a file without using line numbers. - - This method is safer for parallel operations as it doesn't rely on line numbers - that might change when multiple agents modify the same file concurrently. - - Args: - path(str): The relative file path to modify - source(str): The exact content to find and replace (must match exactly including whitespace) - target(str): The new content to replace with - occurrence(int): Which occurrence to replace (1-based). Use -1 to replace all occurrences. - Default is 1 (first occurrence). - - Returns: - Success or error message. - """ - try: - if not source: - return f'Error: You MUST provide the `source` parameter to be replaced with the `target`, but got {source}.' - if target is None: - return f'Error: You MUST provide the `target` parameter to replace the `source`, but got {target}.' - target_path_real = self.get_real_path(path) - if target_path_real is None: - return f'<{path}> is out of the valid project path: {self.output_dir}' - - # Read file content - if not os.path.exists(target_path_real): - return f'Error: File <{path}> does not exist' - - with open(target_path_real, 'r', encoding='utf-8') as f: - file_content = f.read() - - # Check if source exists - if source not in file_content: - return ( - f'Error: Could not find the exact content to replace in <{path}>. ' - f'Make sure the content matches exactly including all whitespace.' - ) - - # Count occurrences - count = file_content.count(source) - - # Replace based on occurrence parameter - if occurrence == -1: - # Replace all occurrences - updated_content = file_content.replace(source, target) - operation_msg = f'Replaced all {count} occurrence(s)' - elif occurrence < 1: - return f'Error: occurrence must be >= 1 or -1 (for all), got {occurrence}' - elif occurrence > count: - return f'Error: occurrence {occurrence} exceeds total occurrences ({count}) of the content' - else: - # Replace specific occurrence - parts = file_content.split(source, occurrence) - if len(parts) <= occurrence: - return f'Error: Could not find occurrence {occurrence} of the content' - # Rejoin: first (occurrence-1) parts with source, then target, then the rest - updated_content = source.join( - parts[:occurrence]) + target + source.join( - parts[occurrence:]) - operation_msg = f'Replaced occurrence {occurrence} of {count}' - - # Write back to file - with open(target_path_real, 'w', encoding='utf-8') as f: - f.write(updated_content) - - return f'{operation_msg} in file <{path}> successfully.' - - except Exception as e: - return f'Replace content in file <{path}> failed, error: ' + str(e) - - async def replace_file_lines(self, - path: str, - content: str, - start_line: int, - end_line: int = None): - """Replace specific line ranges in a file. - - Args: - path(str): The relative file path to modify, a prefix dir will be automatically concatenated. - content(str): The new content to insert/replace - start_line(int): Start line number (1-based, inclusive). Use 0 to insert at beginning, -1 to append at end - end_line(int): End line number (1-based, inclusive). Optional for start_line=0 or -1 - - Returns: - Success or error message. - """ - try: - target_path_real = self.get_real_path(path) - if target_path_real is None: - return f'<{path}> is out of the valid project path: {self.output_dir}' - file_path = target_path_real - # Read existing file content - if os.path.exists(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - else: - # If file doesn't exist, create it - dirname = os.path.dirname(file_path) - if dirname: - os.makedirs(dirname, exist_ok=True) - lines = [] - - total_lines = len(lines) - - # Ensure content ends with newline if it doesn't already - if content and not content.endswith('\n'): - content += '\n' - - # Handle special cases - if start_line == 0: - # Insert at beginning - new_lines = [content] + lines - operation = 'Inserted at beginning' - elif start_line == -1: - # Append at end - new_lines = lines + [content] - operation = 'Appended at end' - else: - # Replace range (1-based, inclusive) - if end_line is None: - return 'Error: end_line is required when start_line is not 0 or -1' - - if start_line < 1 or start_line > total_lines + 1: - return f'Error: start_line {start_line} is out of range (file has {total_lines} lines)' - - if end_line < start_line: - return f'Error: end_line {end_line} must be >= start_line {start_line}' - - # Convert to 0-based indices - start_idx = start_line - 1 - # end_line is inclusive (1-based), so we keep lines from end_line onwards (0-based) - end_idx = end_line - # Lines to keep start from index end_line (which is the line after end_line in 1-based) - - new_lines = lines[:start_idx] + [content] + lines[end_idx:] - operation = f'Replaced lines {start_line}-{end_line}' - - # Write back to file - with open(file_path, 'w', encoding='utf-8') as f: - f.writelines(new_lines) - - target = '\n'.join(new_lines).split('\n') - return ( - f'{operation} in file <{path}> completed successfully. The updated file now has {len(target)} lines. ' - 'WARNING: All line numbers after the replaced range may have shifted. ' - 'If you need to make another line-based replacement in this file, keep this in mind.' - ) - - except Exception as e: - return f'Replace lines in file <{path}> failed, error: ' + str(e) - def get_real_path(self, path): # Check if path is absolute or already starts with output_dir if os.path.isabs(path): @@ -655,7 +786,151 @@ def get_real_path(self, path): else: return target_path_real - async def read_abbreviation_file(self, paths: list[str]): + def _normalize_read_paths(self, paths, path) -> List[str]: + """Accept `paths` (list or lone string) or legacy `path` kwarg.""" + out: List[str] = [] + if paths is not None: + if isinstance(paths, str) and paths.strip(): + out = [paths.strip()] + elif isinstance(paths, list): + out = [ + p.strip() for p in paths + if isinstance(p, str) and p.strip() + ] + if not out and path is not None and isinstance(path, + str) and path.strip(): + out = [path.strip()] + return out + + async def read_file(self, + paths: Optional[List[str]] = None, + path: Optional[str] = None, + offset: int = None, + limit: int = None, + abbreviate: bool = False): + """Read the content of file(s). + + Args: + paths: List of relative file path(s) to read. + path: Single path (alias); used when `paths` is omitted. + offset: Line number to start reading from (1-based). Only effective for a single file. + limit: Number of lines to read. Only effective for a single file. + abbreviate: If True, return an LLM-generated summary instead of raw content. + + Returns: + Dictionary mapping file path(s) to their content or error messages. + """ + paths = self._normalize_read_paths(paths, path) + if not paths: + return json.dumps( + { + 'success': + False, + 'error': + ('read_file requires `paths` (list of strings) or `path` (single string). ' + 'Example: {"paths": ["a.md"]} or {"path": "a.md"}.'), + }, + indent=2, + ) + if abbreviate: + return await self._read_files_abbreviated(paths) + + results = {} + use_line_range = len(paths) == 1 and (offset is not None + or limit is not None) + + for path in paths: + try: + target_path_real = self.get_real_path(path) + if target_path_real is None: + results[path] = ( + f'Access denied: Reading file <{path}> outside output directory is not allowed. ' + f'Set allow_read_all_files=true in config to enable.') + continue + + ext = os.path.splitext(path)[1].lstrip('.').lower() + + # --- Image files --- + if ext in self.IMAGE_EXTENSIONS: + with open(target_path_real, 'rb') as f: + raw = f.read() + media_type = f'image/{ext}' if ext != 'jpg' else 'image/jpeg' + results[path] = { + 'type': 'image', + 'media_type': media_type, + 'base64': base64.b64encode(raw).decode('ascii'), + } + continue + + # --- Text files --- + file_size = os.path.getsize(target_path_real) + if file_size > self.MAX_READ_BYTES and not use_line_range: + results[path] = ( + f'Error: File <{path}> is too large ({file_size} bytes). ' + f'Use offset and limit to read specific portions.') + continue + + # Dedup: return stub if file unchanged since last read + mtime = os.path.getmtime(target_path_real) + cached = self._read_cache.get(target_path_real) + if (cached and cached['mtime'] == mtime + and cached['offset'] == offset + and cached['limit'] == limit): + results[path] = { + 'type': 'file_unchanged', + 'message': 'File has not changed since last read.', + } + continue + + with open(target_path_real, 'rb') as f: + raw_bytes = f.read() + + try: + content = raw_bytes.decode('utf-8') + except UnicodeDecodeError: + results[path] = ( + f'Error: File <{path}> appears to be binary. ' + f'Only text and image files are supported.') + continue + + # Normalize line endings + content = content.replace('\r\n', '\n') + lines = content.splitlines(keepends=True) + total_lines = len(lines) + + if use_line_range: + actual_start = max(1, offset) if offset is not None else 1 + actual_end = min( + actual_start + limit + - 1, total_lines) if limit is not None else total_lines + + if actual_start > total_lines: + results[ + path] = f'Error: offset {offset} exceeds file length ({total_lines} lines)' + continue + selected = lines[actual_start - 1:actual_end] + start_lineno = actual_start + else: + selected = lines + start_lineno = 1 + + results[path] = ''.join(f'{start_lineno + i}\t{line}' + for i, line in enumerate(selected)) + + # Update dedup cache + self._read_cache[target_path_real] = { + 'mtime': mtime, + 'offset': offset, + 'limit': limit, + } + + except FileNotFoundError: + results[path] = f'Read file <{path}> failed: FileNotFound' + except Exception as e: + results[path] = f'Read file <{path}> failed, error: ' + str(e) + return json.dumps(results, indent=2, ensure_ascii=False) + + async def _read_files_abbreviated(self, paths: list[str]) -> str: results = {} def process_file(path): @@ -666,14 +941,15 @@ def process_file(path): index_file = os.path.join(self.index_dir, path.strip(os.sep)) if os.path.exists(index_file): - with open(index_file, 'r', encoding='utf-8') as f: - return path, f.read() + src_mtime = os.path.getmtime(target_path_real) + idx_mtime = os.path.getmtime(index_file) + if idx_mtime >= src_mtime: + with open(index_file, 'r', encoding='utf-8') as f: + return path, f.read() - # Read file content with open(target_path_real, 'r', encoding='utf-8') as f: content = f.read() - # Use LLM to generate abbreviation messages = [ Message(role='system', content=self.system), Message( @@ -691,7 +967,6 @@ def process_file(path): except Exception as e: return path, f'Process file <{path}> failed, error: ' + str(e) - # Use thread pool for parallel LLM API calls with ThreadPoolExecutor(max_workers=4) as executor: future_to_path = { executor.submit(process_file, p): p @@ -703,347 +978,172 @@ def process_file(path): return json.dumps(results, indent=2, ensure_ascii=False) - async def read_file(self, - paths: list[str], - start_line: int = 0, - end_line: int = None): - """Read the content of file(s). + async def edit_file(self, + path: str = None, + old_string: str = None, + new_string: str = None, + replace_all: bool = False): + """Edit a file by replacing an exact string with new content. Args: - paths(`list[str]`): List of relative file path(s) to read, a prefix dir will be automatically concatenated. - start_line(int): Start line number (1-based, inclusive). Only effective when paths has exactly one element. - 0 means from the beginning. - end_line(int): End line number (1-based, inclusive). Only effective when paths has exactly one element. - None means to the end. + path: The relative file path to edit. + old_string: The exact string to find and replace. + new_string: The replacement string. + replace_all: If True, replace all occurrences. Default replaces only the first. Returns: - Dictionary mapping file path(s) to their content or error messages. + Success or error message. """ - results = {} - # Line range is only effective when reading a single file - use_line_range = len(paths) == 1 and (start_line > 0 - or end_line is not None) - - for path in paths: - try: - target_path_real = self.get_real_path(path) - if target_path_real is None: - results[path] = ( - f'Access denied: Reading file <{path}> outside output directory is not allowed. ' - f'Set allow_read_all_files=true in config to enable.') - continue + try: + if old_string is None: + return 'Error: `old_string` is required.' + if new_string is None: + return 'Error: `new_string` is required.' - with open(target_path_real, 'r') as f: - if use_line_range: - # Read specific line range - lines = f.readlines() - total_lines = len(lines) - - # Validate and adjust line numbers (1-based) - actual_start = max(1, - start_line) if start_line > 0 else 1 - actual_end = min( - end_line, total_lines - ) if end_line is not None else total_lines - - if actual_start > total_lines: - results[ - path] = f'Error: start_line {start_line} exceeds file length ({total_lines} lines)' - elif actual_start > actual_end: - results[ - path] = f'Error: start_line {actual_start} > end_line {actual_end}' - else: - # Convert to 0-based index, end_line is inclusive - selected_lines = lines[actual_start - 1:actual_end] - results[path] = ''.join(selected_lines) - else: - # Read entire file - results[path] = f.read() - except FileNotFoundError: - results[path] = f'Read file <{path}> failed: FileNotFound' - except Exception as e: - results[path] = f'Read file <{path}> failed, error: ' + str(e) - return json.dumps(results, indent=2, ensure_ascii=False) + target_path_real = self.get_real_path(path) + if target_path_real is None: + return f'<{path}> is out of the valid project path: {self.output_dir}' - async def delete_file_or_dir(self, path: str): - """Delete a file or a directory. + # --- Special case: old_string="" --- + if old_string == '': + if not os.path.exists(target_path_real): + # Create new file + os.makedirs( + os.path.dirname(target_path_real), exist_ok=True) + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(new_string) + return f'Created file <{path}> successfully.' + with open(target_path_real, 'rb') as f: + existing = f.read() + try: + existing_text = existing.decode('utf-8') + except UnicodeDecodeError: + return f'Error: File <{path}> appears to be binary and cannot be edited as text.' + if existing_text.strip() != '': + return ( + 'Error: `old_string` is empty but the file already has content. ' + 'Use `write_file` for a full rewrite, or provide an `old_string` anchor to insert content.' + ) + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(new_string) + return f'Edit file <{path}> successfully (filled empty file).' - Args: - path(str): The file or directory to delete, a prefix dir will be automatically concatenated. + if not os.path.exists(target_path_real): + return f'Error: File <{path}> does not exist.' - Returns: - boolean - """ - abs_path = os.path.join(self.output_dir, path) - if os.path.exists(abs_path): + with open(target_path_real, 'rb') as f: + raw = f.read() try: - if os.path.isfile(abs_path): - os.remove(abs_path) - else: - shutil.rmtree(abs_path) - return f'Path deleted: <{path}>' - except Exception as e: - return f'Delete file <{path}> failed, error: ' + str(e) - else: - return f'Path not found: {path}' - - async def search_file_name(self, file: str = '', parent_path: str = ''): - """Search for files by name using regex pattern matching. + content = raw.decode('utf-8') + except UnicodeDecodeError: + return f'Error: File <{path}> appears to be binary and cannot be edited as text.' + + # Normalize line endings for matching + content = content.replace('\r\n', '\n') + old_string = old_string.replace('\r\n', '\n') + + # --- Fallback 1: exact match --- + actual_old = old_string if old_string in content else None + + # --- Fallback 2: quote normalization --- + if actual_old is None: + norm_old = self._normalize_quotes(old_string) + norm_content = self._normalize_quotes(content) + idx = norm_content.find(norm_old) + if idx != -1: + actual_old = content[idx:idx + len(old_string)] + + if actual_old is None: + return ( + f'Error: `old_string` not found in <{path}>. ' + f'Make sure it matches the file content exactly including whitespace.' + ) - Args: - file(str): File name pattern (supports regex). If it's a valid regex pattern, - it will be used for regex matching; otherwise, falls back to substring matching. - parent_path(str): Parent path pattern (supports regex for filtering directories). - Can be a simple path or a regex pattern to match directory paths. + count = content.count(actual_old) + if count > 1 and not replace_all: + return ( + f'Error: Found {count} occurrences of `old_string` in <{path}>. ' + f'Add more surrounding context to make it unique, or set replace_all=true.' + ) - Returns: - String containing all matched file paths - """ - parent_path = parent_path or '' - target_path_real = self.get_real_path(parent_path) - if target_path_real is None: - return f'<{parent_path}> is out of the valid project path: {self.output_dir}' - _parent_path = target_path_real - assert os.path.isdir( - _parent_path - ), f'Parent path <{parent_path}> does not exist, it should be a inner relative path of the project folder.' - - # Try to compile file pattern as regex - file_use_regex = False - file_pattern = None - if file: - try: - file_pattern = re.compile(file) - file_use_regex = True - except re.error: - file_use_regex = False - - # Try to compile parent_path filter as regex (optional) - path_use_regex = False - path_pattern = None - if parent_path: - try: - path_pattern = re.compile(parent_path) - path_use_regex = True - except re.error: - path_use_regex = False - - all_found_files = [] - for root, dirs, files in os.walk(_parent_path): - if path_use_regex and parent_path: - relative_root = os.path.relpath(root, self.output_dir) - if not path_pattern.search(relative_root): - continue + # Apply quote style preservation to new_string + actual_new = self._preserve_quote_style(old_string, actual_old, + new_string) - for filename in files: - if file: - if file_use_regex: - is_match = file_pattern.search(filename) is not None - else: - is_match = file in filename - else: - is_match = True # No filter, match all files + # --- Fallback 3: smart delete — strip trailing newline when deleting --- + if actual_new == '' and not actual_old.endswith( + '\n') and actual_old + '\n' in content: + actual_old = actual_old + '\n' - if is_match: - file_path = os.path.join(root, filename) - relative_path = os.path.relpath(file_path, self.output_dir) - all_found_files.append(relative_path) + # Strip trailing whitespace from new_string (skip markdown files) + is_markdown = path.lower().endswith(('.md', '.mdx')) + if not is_markdown: + actual_new = self._strip_trailing_whitespace(actual_new) - if not all_found_files: - return f'No files found matching pattern <{file or "*"}> in <{parent_path or "root"}>' + if replace_all: + updated = content.replace(actual_old, actual_new) + else: + updated = content.replace(actual_old, actual_new, 1) - all_found_files = '\n'.join(all_found_files) - return f'Found {len(all_found_files.splitlines())} file(s) matching <{file or "*"}>:\n{all_found_files}' + with open(target_path_real, 'w', encoding='utf-8') as f: + f.write(updated) - async def search_file_content(self, - content: str = None, - parent_path: str = '.', - file_pattern: str = '*', - context_lines: int = 2): - """Search for content in files using thread pool. - Supports both literal string matching and regex pattern matching automatically. + replaced = count if replace_all else 1 + return f'Edit file <{path}> successfully ({replaced} occurrence(s) replaced).' + except Exception as e: + return f'Edit file <{path}> failed, error: ' + str(e) - Args: - content(str): The content or regex pattern to search for (auto-detected) - parent_path(str): The relative parent path to search in - file_pattern(str): Wildcard pattern for file names (default: '*' for all files) - context_lines(int): Number of lines before and after the match to include (default: 2) - Returns: - String containing all matches with file path, line number, and context - """ - if parent_path.startswith('.' + os.sep): - parent_path = parent_path[len('.' + os.sep):] - if parent_path == '.': - parent_path = '' - target_path_real = self.get_real_path(parent_path) - if target_path_real is None: - return f'<{parent_path}> is out of the valid project path: {self.output_dir}' - _parent_path = target_path_real - assert os.path.isdir( - _parent_path - ), f'Parent path <{parent_path}> does not exist, it should be a inner relative path of the project folder.' - - if not content: - return 'Error: content parameter is required for search' - - # Try to compile as regex pattern, fallback to literal string matching - use_regex = False - pattern = None - try: - pattern = re.compile(content) - use_regex = True - except re.error: - # Not a valid regex, will use literal string matching - use_regex = False - - # Collect all files matching the pattern - files_to_search = [] - for root, dirs, files in os.walk(_parent_path): +def _apply_offset_limit(lines: List[str], offset: int, + head_limit: int) -> List[str]: + if offset: + lines = lines[offset:] + if head_limit and head_limit > 0: + lines = lines[:head_limit] + return lines + + +def _is_relative(path: Path, base: Path) -> bool: + try: + path.relative_to(base) + return True + except ValueError: + return False + + +def _is_denied_path(path: Path, root: Path, deny: tuple[str, ...]) -> bool: + if not deny: + return False + try: + rel = path.relative_to(root).as_posix() + except ValueError: + rel = path.as_posix() + for pat in deny: + if fnmatch.fnmatch(rel, pat): + return True + return False + + +def _walk_files_limited(root: Path, deny: tuple[str, ...], + max_files: int) -> List[Path]: + out: List[Path] = [] + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): + dp = Path(dirpath) + pruned = [] + for d in list(dirnames): + child = dp / d try: - test_dir = str(Path(root).relative_to(self.output_dir)) + rel = child.relative_to(root).as_posix() except ValueError: - test_dir = str(root) - if test_dir == '.': - test_dir = '' - if any(excluded_dir in root - for excluded_dir in self.EXCLUDED_DIRS): + rel = child.as_posix() + skip = any(fnmatch.fnmatch(rel, p) for p in deny) + if skip: continue - for filename in files: - # Skip excluded files - if filename.startswith( - self.EXCLUDED_FILE_PREFIXES) or test_dir.startswith( - self.EXCLUDED_FILE_PREFIXES): - continue - # Match file pattern - if fnmatch.fnmatch(filename, file_pattern): - files_to_search.append(os.path.join(root, filename)) - - if not files_to_search: - return f'No files matching pattern <{file_pattern}> found in <{parent_path or "root"}>' - - # Function to search in a single file - def search_in_file(file_path): - matches = [] - with open(file_path, 'r', encoding='utf-8') as f: - lines = f.readlines() - for line_num, line in enumerate(lines, start=1): - # Check for match: regex or literal string - is_match = False - if use_regex: - is_match = pattern.search(line) is not None - else: - is_match = content in line - - if is_match: - # Calculate context range - start_line = max(0, line_num - context_lines - 1) - end_line = min(len(lines), line_num + context_lines) - - # Extract context lines - context = [] - for i in range(start_line, end_line): - prefix = '> ' if i == line_num - 1 else ' ' - context.append( - f'{prefix}{i + 1:4d} | {lines[i].rstrip()}') - - relative_path = os.path.relpath( - file_path, self.output_dir) - matches.append({ - 'file': relative_path, - 'line': line_num, - 'context': '\n'.join(context) - }) - return matches - - # Use thread pool to search files in parallel - all_matches = [] - with ThreadPoolExecutor(max_workers=8) as executor: - future_to_file = { - executor.submit(search_in_file, f): f - for f in files_to_search - } - for future in as_completed(future_to_file): - matches = future.result() - all_matches.extend(matches) - - if not all_matches: - return f'No matches found for <{content}> in files matching <{file_pattern}>' - - # Format results - result_lines = [ - f'Found {len(all_matches)} match(es) for "{content}":\n' - ] - for match in all_matches: - result_lines.append( - f"File: {match['file']}, Line: {match['line']}") - result_lines.append(match['context']) - result_lines.append('') - - return '\n'.join(result_lines) - - async def list_files(self, path: str = None): - """List all files in a directory. - - Args: - path: The relative path to traverse, a prefix dir will be automatically concatenated. - - Returns: - The file names concatenated as a string - """ - file_paths = [] - if not path or path == '.': - path = self.output_dir - else: - path = os.path.join(self.output_dir, path) - if path.startswith('.' + os.sep): - path = path[len('.' + os.sep):] - try: - for root, dirs, files in os.walk(path): - try: - test_dir = str(Path(root).relative_to(self.output_dir)) - except ValueError: - test_dir = str(root) - if test_dir == '.': - test_dir = '' - for file in files: - # Skip excluded directories and files - root_exclude = any(excluded_dir in root - for excluded_dir in self.EXCLUDED_DIRS) - if root_exclude or file.startswith( - self.EXCLUDED_FILE_PREFIXES - ) or test_dir.startswith(self.EXCLUDED_FILE_PREFIXES): - continue - absolute_path = os.path.join(root, file) - relative_path = os.path.relpath(absolute_path, path) - file_paths.append(relative_path) - return '\n'.join(file_paths) or f'No files in path: {path}' - except Exception as e: - return f'List files of <{path or "root path"}> failed, error: ' + str( - e) - - @retry(max_attempts=MAX_CONTINUE_RUNS, delay=1.0) - async def edit_file(self, - path: str = None, - instructions: str = None, - code_edit: str = None): - try: - with open(os.path.join(self.output_dir, path), 'r') as f: - initial_code = f.read() - response = self.edit_client.chat.completions.create( - model=self.edit_file_config.diff_model, - messages=[{ - 'role': - 'user', - 'content': - (f'{instructions}\n' - f'{initial_code}\n' - f'{code_edit}') - }]) - merged_code = response.choices[0].message.content - - with open(os.path.join(self.output_dir, path), 'w') as f: - f.write(merged_code) - return f'Edit file <{path}> successfully.' - except Exception as e: - return f'Edit file <{path}> failed, error: ' + str(e) + pruned.append(d) + dirnames[:] = pruned + for name in filenames: + out.append(dp / name) + if len(out) >= max_files: + return out + return out diff --git a/ms_agent/tools/findata/akshare_source.py b/ms_agent/tools/findata/akshare_source.py index 6e68aaf41..b9fa3f7b4 100644 --- a/ms_agent/tools/findata/akshare_source.py +++ b/ms_agent/tools/findata/akshare_source.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd import re from typing import Any, Dict, List, Optional -import pandas as pd from ms_agent.tools.findata.data_source_base import (DataSourceError, FinancialDataSource, NoDataFoundError) diff --git a/ms_agent/tools/findata/baostock_source.py b/ms_agent/tools/findata/baostock_source.py index dd05b7a1f..9379b3ba2 100644 --- a/ms_agent/tools/findata/baostock_source.py +++ b/ms_agent/tools/findata/baostock_source.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd import threading from contextlib import contextmanager from copy import deepcopy from typing import Any, Dict, List, Optional -import pandas as pd from ms_agent.tools.findata.data_source_base import (DataSourceError, FinancialDataSource, NoDataFoundError) diff --git a/ms_agent/tools/findata/data_source_base.py b/ms_agent/tools/findata/data_source_base.py index d287aebdb..e78b9fd29 100644 --- a/ms_agent/tools/findata/data_source_base.py +++ b/ms_agent/tools/findata/data_source_base.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -import pandas as pd - class DataSourceError(Exception): """Base data source error class""" diff --git a/ms_agent/tools/findata/findata_fetcher.py b/ms_agent/tools/findata/findata_fetcher.py index 996f28781..2b5afdcfd 100644 --- a/ms_agent/tools/findata/findata_fetcher.py +++ b/ms_agent/tools/findata/findata_fetcher.py @@ -1,14 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import json +import numpy as np +import pandas as pd from concurrent.futures import ThreadPoolExecutor from datetime import date, datetime from functools import partial +from omegaconf import DictConfig from pathlib import Path from typing import Any, Dict, Optional, Union -import json -import numpy as np -import pandas as pd from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.tools.findata.akshare_source import AKShareDataSource @@ -19,7 +20,6 @@ from ms_agent.tools.findata.hybrid_source import HybridDataSource from ms_agent.utils import get_logger from ms_agent.utils.rate_limiter import AdaptiveRateLimiter, RateLimiter -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/tools/findata/hybrid_source.py b/ms_agent/tools/findata/hybrid_source.py index 79380fd14..66b246c8e 100644 --- a/ms_agent/tools/findata/hybrid_source.py +++ b/ms_agent/tools/findata/hybrid_source.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import pandas as pd import re from typing import Any, Callable, Dict, List, Optional -import pandas as pd from ms_agent.tools.findata.akshare_source import AKShareDataSource from ms_agent.tools.findata.baostock_source import BaoStockDataSource from ms_agent.tools.findata.data_source_base import (DataSourceError, diff --git a/ms_agent/tools/image_generator/ds_image_gen.py b/ms_agent/tools/image_generator/ds_image_gen.py index b7286b218..ea83d1d06 100644 --- a/ms_agent/tools/image_generator/ds_image_gen.py +++ b/ms_agent/tools/image_generator/ds_image_gen.py @@ -1,7 +1,6 @@ import os import uuid from io import BytesIO - from PIL import Image diff --git a/ms_agent/tools/image_generator/ms_image_gen.py b/ms_agent/tools/image_generator/ms_image_gen.py index b121458e6..b61b6d7af 100644 --- a/ms_agent/tools/image_generator/ms_image_gen.py +++ b/ms_agent/tools/image_generator/ms_image_gen.py @@ -1,9 +1,8 @@ import asyncio +import json import os import uuid from io import BytesIO - -import json from PIL import Image diff --git a/ms_agent/tools/jina_reader.py b/ms_agent/tools/jina_reader.py index d8962f617..f60b39cd8 100644 --- a/ms_agent/tools/jina_reader.py +++ b/ms_agent/tools/jina_reader.py @@ -1,13 +1,21 @@ import asyncio +import html as html_module import random +import re import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from urllib.error import HTTPError, URLError -from urllib.parse import quote +from urllib.parse import quote, urlparse from urllib.request import Request, urlopen +from ms_agent.tools.fetch_playwright_fallback import ( + looks_like_spa_shell_html, try_playwright_inner_text) +from ms_agent.utils.logger import get_logger + +logger = get_logger() + DEFAULT_HEADERS: Dict[str, str] = { 'User-Agent': 'Mozilla/5.0 (compatible; ms-agent/1.0; +https://example.com)', @@ -15,16 +23,36 @@ 'Accept-Language': 'en-US,en;q=0.9', } +# Cap body size for direct HTTP fallback (same order of magnitude as MAX_FETCH_CHARS). +_MAX_DIRECT_RESPONSE_BYTES = 10 * 1024 * 1024 + +_DIRECT_FETCH_HEADERS: Dict[str, str] = { + 'User-Agent': DEFAULT_HEADERS['User-Agent'], + 'Accept': + 'text/html,application/xhtml+xml,application/xml;q=0.9,text/plain;q=0.8,*/*;q=0.7', + 'Accept-Language': DEFAULT_HEADERS['Accept-Language'], +} + @dataclass class JinaReaderConfig: base_endpoint: str = 'https://r.jina.ai/' - timeout: float = 30.0 + timeout: float = 45.0 retries: int = 3 backoff_base: float = 0.8 backoff_max: float = 8.0 headers: Dict[str, str] = field(default_factory=lambda: DEFAULT_HEADERS.copy()) + # When Jina Reader returns empty after retries, try HTTP GET on the target URL. + direct_fetch_fallback: bool = True + # Tier 2 (urllib): shorter than Jina timeout — fail fast on slow origins. + direct_fetch_timeout: float = 15.0 + # Tier 3: headless Chromium when direct body is empty/short or looks like a JS shell. + playwright_fetch_fallback: bool = True + playwright_retry_min_chars: int = 400 + playwright_timeout_ms: int = 30_000 + # After domcontentloaded, brief wait for client hydration (lower = faster). + playwright_settle_ms: int = 350 def _build_reader_url(target_url: str, base_endpoint: str) -> str: @@ -50,10 +78,85 @@ def _postprocess_text(raw_text: str) -> str: return text.strip() -def fetch_single_text(url: str, config: JinaReaderConfig) -> str: +def _is_direct_http_allowed(url: str) -> bool: + try: + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https'): + return False + if not parsed.netloc: + return False + return True + except Exception: + return False + + +def _html_to_plaintext(html: str) -> str: + """Best-effort HTML → text without extra dependencies.""" + text = re.sub(r'(?is)]*>.*?', ' ', html) + text = re.sub(r'(?is)]*>.*?', ' ', text) + text = re.sub(r'(?is)]*>.*?', ' ', text) + text = re.sub( + r'(?i)', + '\n', + text, + ) + text = re.sub(r'', '\n', text, flags=re.IGNORECASE) + text = re.sub(r'<[^>]+>', ' ', text) + text = html_module.unescape(text) + text = re.sub(r'[ \t\f\v]+', ' ', text) + text = re.sub(r'\n{3,}', '\n\n', text) + return text.strip() + + +# Snippet size for SPA heuristics (avoid holding multi‑MB strings in memory). +_DIRECT_HTML_HEURISTIC_CAP = 120_000 + + +def _fetch_direct_http_pair(url: str, timeout: float) -> Tuple[str, str]: """ - Synchronous fetch of a single URL via Jina Reader with retry/backoff and postprocessing. + Fetch the target URL over HTTP(S) without Jina. + + Returns: + (plaintext, raw_html_snippet) — ``raw_html_snippet`` is non-empty only when + the response was treated as HTML (used for shell / length heuristics). """ + if not _is_direct_http_allowed(url): + return '', '' + try: + req = Request(url, headers=_DIRECT_FETCH_HEADERS) + with urlopen(req, timeout=timeout) as resp: + raw = resp.read(_MAX_DIRECT_RESPONSE_BYTES + 1) + if len(raw) > _MAX_DIRECT_RESPONSE_BYTES: + raw = raw[:_MAX_DIRECT_RESPONSE_BYTES] + charset = resp.headers.get_content_charset() or 'utf-8' + content_type = (resp.headers.get('Content-Type') or '').lower() + content_type_main = content_type.split(';')[0].strip() + text = raw.decode(charset, errors='replace') + if 'html' in content_type_main or text.lstrip().lower().startswith( + ' bool: + """Whether tier-3 headless fetch is worth attempting.""" + p = plain.strip() + if raw_html: + if looks_like_spa_shell_html(raw_html): + return True + if len(p) < min_chars: + return True + return False + return not bool(p) + + +def _fetch_via_jina(url: str, config: JinaReaderConfig) -> str: + """Jina Reader only; returns empty string on failure.""" request_url = _build_reader_url(url, config.base_endpoint) attempt = 0 while True: @@ -62,10 +165,8 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: req = Request(request_url, headers=config.headers) with urlopen(req, timeout=config.timeout) as resp: data = resp.read() - return _postprocess_text( - data.decode('utf-8', errors='replace')) + return data.decode('utf-8', errors='replace') except HTTPError as e: - # Retry on 429 and 5xx, otherwise fail fast status = getattr(e, 'code', None) if status in (429, 500, 502, 503, 504) and attempt <= config.retries: @@ -84,7 +185,6 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: continue return '' except Exception: - # Unknown error; do not loop excessively if attempt <= config.retries: sleep_s = min(config.backoff_max, config.backoff_base * (2**(attempt - 1))) @@ -94,6 +194,62 @@ def fetch_single_text(url: str, config: JinaReaderConfig) -> str: return '' +def fetch_single_text_with_meta( + url: str, config: JinaReaderConfig) -> Tuple[str, Dict[str, Any]]: + """ + Tiered fetch: Jina Reader → direct HTTP → optional Playwright (empty / short / SPA shell). + + Returns: + (text, meta) where ``meta['content_source']`` is one of: + ``jina_reader`` | ``direct_http_fallback`` | ``playwright_fallback`` | ``none``. + """ + jina_raw = _fetch_via_jina(url, config) + jina_text = _postprocess_text(jina_raw) + if jina_text: + return jina_text, {'content_source': 'jina_reader'} + if not config.direct_fetch_fallback: + return '', {'content_source': 'none'} + d_timeout = ( + float(config.timeout) if float(config.direct_fetch_timeout or 0) <= 0 + else float(config.direct_fetch_timeout)) + direct_plain, raw_html = _fetch_direct_http_pair(url, d_timeout) + direct_text = _postprocess_text(direct_plain) + + try_playwright = ( + bool(config.playwright_fetch_fallback) and _is_direct_http_allowed(url) + and _should_try_playwright_after_direct( + direct_text, raw_html, config.playwright_retry_min_chars)) + + if try_playwright: + pw_text = _postprocess_text( + try_playwright_inner_text( + url, + int(config.playwright_timeout_ms), + settle_ms=int(config.playwright_settle_ms), + )) + if pw_text.strip(): + logger.info( + 'Using headless Chromium fallback after Jina/direct HTTP ' + f'(url prefix): {url[:80]}') + return pw_text, {'content_source': 'playwright_fallback'} + + if direct_text: + logger.info( + 'Jina Reader returned no body for URL; using direct HTTP fallback ' + f'(url prefix): {url[:80]}') + return direct_text, {'content_source': 'direct_http_fallback'} + return '', {'content_source': 'none'} + + +def fetch_single_text(url: str, config: JinaReaderConfig) -> str: + """ + Synchronous fetch of a single URL via Jina Reader with retry/backoff, + then optional direct HTTP fallback when Jina yields empty. + """ + text, _meta = fetch_single_text_with_meta(url, config) + return text + + async def fetch_texts_via_jina( urls: List[str], config: Optional[JinaReaderConfig] = None, diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index d6b971fbb..7fbfe148c 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -2,18 +2,18 @@ import os from contextlib import AsyncExitStack from datetime import timedelta -from types import TracebackType -from typing import Any, Dict, Literal, Optional - from mcp import ClientSession, ListToolsResult, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from omegaconf import DictConfig +from types import TracebackType +from typing import Any, Dict, Literal, Optional + from ms_agent.config import Config from ms_agent.config.env import Env from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import enhance_error, get_logger -from omegaconf import DictConfig logger = get_logger() @@ -184,7 +184,8 @@ async def connect_to_server(self, 'Using streamable_http transport. To configure a different transport such as sse, please' 'set the `type` or `transport` variable to "sse".') try: - from mcp.client.streamable_http import streamablehttp_client + from mcp.client.streamable_http import \ + streamablehttp_client except ImportError: raise ImportError( 'Could not import streamablehttp_client. ' diff --git a/ms_agent/tools/mineru/pdf_parser.py b/ms_agent/tools/mineru/pdf_parser.py index d210fd9b1..284fbff69 100644 --- a/ms_agent/tools/mineru/pdf_parser.py +++ b/ms_agent/tools/mineru/pdf_parser.py @@ -1,5 +1,4 @@ import os - from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.data.data_reader_writer import (FileBasedDataReader, FileBasedDataWriter) diff --git a/ms_agent/tools/search/arxiv/schema.py b/ms_agent/tools/search/arxiv/schema.py index 154c603ae..2e9439904 100644 --- a/ms_agent/tools/search/arxiv/schema.py +++ b/ms_agent/tools/search/arxiv/schema.py @@ -1,10 +1,10 @@ # flake8: noqa -from dataclasses import dataclass, field -from typing import Any, Dict, Generator, List, Optional - import arxiv import json from arxiv import SortCriterion, SortOrder +from dataclasses import dataclass, field +from typing import Any, Dict, Generator, List, Optional + from ms_agent.tools.search.search_base import (BaseResult, SearchRequest, SearchResponse, SearchResult) from ms_agent.utils.logger import get_logger diff --git a/ms_agent/tools/search/arxiv/search.py b/ms_agent/tools/search/arxiv/search.py index 8f407d454..4888c2f3a 100644 --- a/ms_agent/tools/search/arxiv/search.py +++ b/ms_agent/tools/search/arxiv/search.py @@ -1,10 +1,10 @@ # flake8: noqa +import arxiv import os +from arxiv import SortCriterion, SortOrder from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING -import arxiv -from arxiv import SortCriterion, SortOrder from ms_agent.tools.search.arxiv.schema import (ArxivSearchRequest, ArxivSearchResult) from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType diff --git a/ms_agent/tools/search/content_optimizer.py b/ms_agent/tools/search/content_optimizer.py index 998b13b66..1e62798cb 100644 --- a/ms_agent/tools/search/content_optimizer.py +++ b/ms_agent/tools/search/content_optimizer.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import os import re from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from datetime import datetime +from omegaconf import DictConfig, OmegaConf from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse -import json from ms_agent.llm.openai_llm import OpenAI from ms_agent.llm.utils import Message from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor -from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/ms_agent/tools/search/exa/schema.py b/ms_agent/tools/search/exa/schema.py index a80a1c401..48b51c06b 100644 --- a/ms_agent/tools/search/exa/schema.py +++ b/ms_agent/tools/search/exa/schema.py @@ -1,9 +1,8 @@ # flake8: noqa -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - import json +from dataclasses import dataclass, field from exa_py.api import SearchResponse +from typing import Any, Dict, List, Optional @dataclass diff --git a/ms_agent/tools/search/exa/search.py b/ms_agent/tools/search/exa/search.py index be456c0a5..d09ffa237 100644 --- a/ms_agent/tools/search/exa/search.py +++ b/ms_agent/tools/search/exa/search.py @@ -1,14 +1,18 @@ # flake8: noqa import os -from typing import TYPE_CHECKING - +import threading from exa_py import Exa +from typing import TYPE_CHECKING, List, Optional, Set, Union + from ms_agent.tools.search.exa.schema import ExaSearchRequest, ExaSearchResult from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType +from ms_agent.utils.logger import get_logger if TYPE_CHECKING: from ms_agent.llm.utils import Tool +logger = get_logger() + class ExaSearch(SearchEngine): """ @@ -16,21 +20,129 @@ class ExaSearch(SearchEngine): Best for: semantic understanding, finding similar content, recent web pages with date filtering. + + Supports a pool of API keys: when one key's credits are exhausted (HTTP 402), + the engine automatically rotates to the next available key and retries. + Keys can be supplied as a comma-separated string or a list. + + Exhausted-key state is tracked at the **class level** so that all + ``ExaSearch`` instances within the same process share the knowledge + of which keys have been used up (e.g. across multiple searcher sub-agents). """ engine_type = SearchEngineType.EXA - def __init__(self, api_key: str = None): + # Process-wide tracking of exhausted key values (shared across instances). + _global_exhausted_keys: Set[str] = set() + _global_lock = threading.Lock() + + def __init__(self, + api_key: Union[str, list, None] = None, + api_keys: Union[str, list, None] = None): + all_keys = self._collect_keys(api_key, api_keys) + assert all_keys, ( + 'EXA_API_KEY or EXA_API_KEYS must be set either as arguments ' + 'or as environment variables') + + self._api_keys: List[str] = all_keys + self._lock = threading.Lock() + + # Pick the first key that hasn't been globally exhausted yet. + start_idx = 0 + with ExaSearch._global_lock: + for i, k in enumerate(all_keys): + if k not in ExaSearch._global_exhausted_keys: + start_idx = i + break + + self._current_key_idx: int = start_idx + self.client = Exa(api_key=all_keys[start_idx]) + + if len(all_keys) > 1: + with ExaSearch._global_lock: + n_exhausted = sum(1 for k in all_keys + if k in ExaSearch._global_exhausted_keys) + logger.info(f'Exa key pool: {len(all_keys)} keys, ' + f'{n_exhausted} previously exhausted, ' + f'starting at key {start_idx + 1}/{len(all_keys)}') + + @staticmethod + def _collect_keys( + api_key: Union[str, list, None] = None, + api_keys: Union[str, list, None] = None, + ) -> List[str]: + """Collect unique API keys from arguments and environment variables. + + All sources are **merged** (deduplicated), so keys from both YAML config + and ``EXA_API_KEYS`` env var are combined into a single pool. + + Sources (in merge order): + 1. ``api_keys`` argument (list or comma-separated string) + 2. ``api_key`` argument (single key or comma-separated string) + 3. ``EXA_API_KEYS`` env var (comma-separated) -- always merged + 4. ``EXA_API_KEY`` env var (only if no keys found so far) + """ + seen: set = set() + result: List[str] = [] + + def _add(raw: str): + for k in raw.split(','): + k = k.strip() + if k and k not in seen: + seen.add(k) + result.append(k) + + def _add_source(value): + if value is None: + return + if isinstance(value, str): + _add(value) + elif hasattr(value, '__iter__'): + for item in value: + if item is not None: + _add(str(item)) + + _add_source(api_keys) + _add_source(api_key) + + # Always merge the pool env var so that keys from YAML config and + # EXA_API_KEYS are combined (the old code gated this behind + # ``if not result`` which made it unreachable when api_key was set). + _add_source(os.getenv('EXA_API_KEYS')) + + if not result: + _add_source(os.getenv('EXA_API_KEY')) + + return result + + @staticmethod + def _is_credits_exhausted(error: Exception) -> bool: + """Detect Exa 402 / NO_MORE_CREDITS errors.""" + msg = str(error) + return ('402' in msg + and ('credits' in msg.lower() or 'NO_MORE_CREDITS' in msg)) - api_key = api_key or os.getenv('EXA_API_KEY') - assert api_key, 'EXA_API_KEY must be set either as an argument or as an environment variable' + @staticmethod + def _mask_key(key: str) -> str: + if len(key) <= 8: + return '****' + return f'{key[:4]}...{key[-4:]}' - self.client = Exa(api_key=api_key) + def _is_key_exhausted(self, idx: int) -> bool: + with ExaSearch._global_lock: + return self._api_keys[idx] in ExaSearch._global_exhausted_keys + + def _mark_key_exhausted(self, idx: int) -> None: + with ExaSearch._global_lock: + ExaSearch._global_exhausted_keys.add(self._api_keys[idx]) def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: """ Perform a search using the Exa API with the provided search request parameters. + If the current key is exhausted (HTTP 402 / NO_MORE_CREDITS), the engine + rotates to the next available key and retries, up to ``len(api_keys)`` times. + :param search_request: An instance of ExaSearchRequest containing search parameters. :return: An instance of ExaSearchResult containing the search results. """ @@ -39,13 +151,53 @@ def search(self, search_request: ExaSearchRequest) -> ExaSearchResult: query=search_request.query, arguments=search_args, ) - try: - search_result.response = self.client.search_and_contents( - **search_args) - except Exception as e: - raise RuntimeError(f'Failed to perform search: {e}') from e - return search_result + last_error: Optional[Exception] = None + max_attempts = len(self._api_keys) + instance_exhausted: Set[int] = set() + + for _attempt in range(max_attempts): + with self._lock: + client = self.client + key_idx = self._current_key_idx + + try: + search_result.response = client.search_and_contents( + **search_args) + return search_result + except Exception as e: + if not self._is_credits_exhausted(e): + raise RuntimeError(f'Failed to perform search: {e}') from e + + last_error = e + instance_exhausted.add(key_idx) + self._mark_key_exhausted(key_idx) + + with self._lock: + logger.warning( + f'Exa API key {self._mask_key(self._api_keys[key_idx])} ' + f'credits exhausted ' + f'({len(instance_exhausted)}/{len(self._api_keys)} keys used up)' + ) + rotated = False + for i in range(len(self._api_keys)): + if i not in instance_exhausted and not self._is_key_exhausted( + i): + self._current_key_idx = i + self.client = Exa(api_key=self._api_keys[i]) + logger.info(f'Rotated to Exa API key ' + f'{self._mask_key(self._api_keys[i])} ' + f'({i + 1}/{len(self._api_keys)})') + rotated = True + break + if not rotated: + raise RuntimeError( + f'All {len(self._api_keys)} Exa API keys have ' + f'been exhausted. Last error: {e}') from e + + raise RuntimeError( + f'All {len(self._api_keys)} Exa API keys have been exhausted. ' + f'Last error: {last_error}') from last_error @classmethod def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': diff --git a/ms_agent/tools/search/localsearch_tool.py b/ms_agent/tools/search/localsearch_tool.py new file mode 100644 index 000000000..2449d4453 --- /dev/null +++ b/ms_agent/tools/search/localsearch_tool.py @@ -0,0 +1,282 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""On-demand local codebase search via sirchmunk (replaces pre-turn RAG injection).""" + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.tools.search.sirchmunk_search import ( + SirchmunkSearch, effective_localsearch_settings) +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SERVER = 'localsearch' +_TOOL = 'localsearch' + +# Tool-facing description: aligned with sirchmunk AgenticSearch.search() capabilities. +_LOCALSEARCH_DESCRIPTION = """Search local files, codebases, and documents on disk. + +USE THIS TOOL WHEN: +- The user asks about content in local files or directories +- You need to find information in source code, config files, or documents +- The query references a local path, project structure, or codebase +- You need to search PDF, DOCX, XLSX, PPTX, CSV, JSON, YAML, Markdown, etc. + +DO NOT USE THIS TOOL WHEN: +- The user is asking a general knowledge question +- The user is greeting you or making casual conversation (e.g., "你好", "hello") +- You need information from the internet or recent events +- The query has no relation to local files or code + +Returns: +Search results after summarizing as formatted text with file paths, code snippets, and explanations where +available. Retrieved excerpts and meta are included in the tool output. + +Configured search roots for this agent (absolute paths; default search scope when `paths` is omitted): +{configured_roots} +""" + + +def _resolved_localsearch_paths_from_config(config) -> List[str]: + """Match ``SirchmunkSearch`` path resolution for consistent tool text and checks.""" + block = effective_localsearch_settings(config) + if not block: + return [] + paths = block.get('paths', []) + if isinstance(paths, str): + paths = [paths] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(Path(str(p).strip()).expanduser().resolve())) + return out + + +def _format_configured_roots(paths: List[str]) -> str: + if not paths: + return ('(none — set tools.localsearch.paths in agent config, ' + 'or legacy knowledge_search.paths)') + return '\n'.join(f'- {p}' for p in paths) + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2) + + +def _as_str_list(value: Any, name: str) -> Optional[List[str]]: + if value is None: + return None + if isinstance(value, str): + return [value] if value.strip() else None + if isinstance(value, list): + out = [str(x).strip() for x in value if str(x).strip()] + return out or None + raise TypeError(f'{name} must be a string or list of strings') + + +class LocalSearchTool(ToolBase): + """Expose sirchmunk as a callable tool when ``tools.localsearch`` is configured.""" + + def __init__(self, config, **kwargs): + super().__init__(config) + tools_root = getattr(config, 'tools', None) + tool_cfg = getattr(tools_root, 'localsearch', + None) if tools_root else None + if tool_cfg is not None: + self.exclude_func(tool_cfg) + self._searcher: Optional[SirchmunkSearch] = None + self._configured_roots: List[str] = ( + _resolved_localsearch_paths_from_config(config)) + + def _tool_description(self) -> str: + return _LOCALSEARCH_DESCRIPTION.format( + configured_roots=_format_configured_roots(self._configured_roots)) + + def _paths_param_description(self) -> str: + roots = _format_configured_roots(self._configured_roots) + return ( + 'Optional. Narrow search to specific files or directories under the ' + 'configured roots below. Each path must exist on disk and lie under ' + 'one of these roots (or be exactly one of them).\n' + f'Configured roots:\n{roots}') + + def _ensure_searcher(self) -> SirchmunkSearch: + if self._searcher is None: + self._searcher = SirchmunkSearch(self.config) + return self._searcher + + async def connect(self) -> None: + return None + + async def _get_tools_inner(self) -> Dict[str, List[Tool]]: + return { + _SERVER: [ + Tool( + tool_name=_TOOL, + server_name=_SERVER, + description=self._tool_description(), + parameters={ + 'type': 'object', + 'properties': { + 'query': { + 'type': + 'string', + 'description': + 'Search keywords or natural-language question about local content.', + }, + 'paths': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': self._paths_param_description(), + }, + 'mode': { + 'type': + 'string', + 'enum': ['FAST', 'DEEP', 'FILENAME_ONLY'], + 'description': + 'Search mode; omit to use agent default (usually FAST).', + }, + 'max_depth': { + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, + 'description': + 'Max directory depth for filesystem search.', + }, + 'top_k_files': { + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, + 'description': + 'Max files for evidence / filename hits.', + }, + 'include': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to include (e.g. *.py, *.md).', + }, + 'exclude': { + 'type': + 'array', + 'items': { + 'type': 'string' + }, + 'description': + 'Glob patterns to exclude (e.g. *.pyc).', + }, + }, + 'required': ['query'], + }, + ) + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict): + del server_name + if tool_name != _TOOL: + return f'Unknown tool: {tool_name}' + + args = tool_args or {} + query = str(args.get('query', '')).strip() + if not query: + return 'Error: `query` is required and cannot be empty.' + + try: + paths_arg = _as_str_list(args.get('paths'), 'paths') + mode = args.get('mode') + if mode is not None: + mode = str(mode).strip().upper() or None + + max_depth = args.get('max_depth') + if max_depth is not None: + max_depth = int(max_depth) + max_depth = max(1, min(20, max_depth)) + + top_k = args.get('top_k_files') + if top_k is not None: + top_k = int(top_k) + top_k = max(1, min(20, top_k)) + + include = _as_str_list(args.get('include'), 'include') + exclude = _as_str_list(args.get('exclude'), 'exclude') + + searcher = self._ensure_searcher() + resolved_paths = None + if paths_arg: + resolved_paths = searcher.resolve_tool_paths(paths_arg) + if not resolved_paths: + roots = _format_configured_roots(self._configured_roots) + return ( + 'Error: `paths` are invalid. Each path must exist on disk and lie ' + 'under one of these configured roots:\n' + roots) + + answer = await searcher.query( + query, + paths=resolved_paths, + mode=mode, + max_depth=max_depth, + top_k_files=top_k, + include=include, + exclude=exclude, + ) + details = searcher.get_search_details() + excerpts = searcher.get_last_retrieved_chunks() + + lines = ['## Local search (sirchmunk)', '', str(answer), ''] + + if excerpts: + lines.append('## Retrieved excerpts') + lines.append('') + for i, item in enumerate(excerpts[:12], 1): + meta = item.get('metadata') or {} + src = meta.get('source', '?') + text = (item.get('text') or '')[:4000] + lines.append(f'### [{i}] {src}') + lines.append(text) + lines.append('') + + summary = { + 'mode': details.get('mode'), + 'paths': details.get('paths'), + 'work_path': details.get('work_path'), + 'cluster_cache_hit': details.get('cluster_cache_hit'), + } + lines.append('## Meta') + lines.append(_json_dumps(summary)) + + full_text = '\n'.join(lines) + # Model sees answer + source paths only; UI gets full excerpts + meta. + result_parts = [str(answer).strip()] + if excerpts: + result_parts.append('\nSource paths:') + for item in excerpts[:12]: + meta = item.get('metadata') or {} + result_parts.append(f'- {meta.get("source", "?")}') + result_text = '\n'.join(result_parts) + + return { + 'result': result_text, + 'tool_detail': full_text, + } + except (TypeError, ValueError) as exc: + return f'Invalid tool arguments: {exc}' + except Exception as exc: + logger.warning(f'localsearch failed: {exc}') + return f'Local search failed: {exc}' diff --git a/ms_agent/tools/search/search_base.py b/ms_agent/tools/search/search_base.py index bb6952729..3c125f11d 100644 --- a/ms_agent/tools/search/search_base.py +++ b/ms_agent/tools/search/search_base.py @@ -1,12 +1,11 @@ # flake8: noqa import enum +import json import os from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar -import json - if TYPE_CHECKING: from ms_agent.llm.utils import Tool @@ -17,6 +16,7 @@ class SearchEngineType(enum.Enum): EXA = 'exa' SERPAPI = 'serpapi' ARXIV = 'arxiv' + TAVILY = 'tavily' # Mapping from engine type to tool name @@ -24,6 +24,7 @@ class SearchEngineType(enum.Enum): 'exa': 'exa_search', 'serpapi': 'serpapi_search', 'arxiv': 'arxiv_search', + 'tavily': 'tavily_search', } diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/tools/search/sirchmunk_search.py similarity index 68% rename from ms_agent/knowledge_search/sirchmunk_search.py rename to ms_agent/tools/search/sirchmunk_search.py index e1c76181f..cd8aaacf5 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/tools/search/sirchmunk_search.py @@ -1,51 +1,82 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Sirchmunk-based knowledge search integration. +"""Sirchmunk backend for the ``localsearch`` tool. -This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, -providing document retrieval capabilities similar to RAG but optimized for -codebase and documentation search. +Configuration lives under ``tools.localsearch`` (same namespace as other tools). +Legacy top-level ``knowledge_search`` is still accepted for backward compatibility. """ import asyncio -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union - +import json from loguru import logger -from ms_agent.rag.base import RAG from omegaconf import DictConfig +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional -class SirchmunkSearch(RAG): - """Sirchmunk-based knowledge search class. +def _paths_from_block(block: Any) -> List[str]: + if block is None: + return [] + paths = block.get('paths', []) if hasattr(block, 'get') else [] + if isinstance(paths, str): + paths = [paths] if str(paths).strip() else [] + out: List[str] = [] + for p in paths or []: + if p is None or not str(p).strip(): + continue + out.append(str(p).strip()) + return out - This class wraps the sirchmunk library to provide intelligent codebase search - capabilities. Unlike traditional RAG that uses vector embeddings, Sirchmunk - uses a combination of keyword search, semantic clustering, and LLM-powered - analysis to find relevant information from codebases. - The configuration needed in the config yaml: - - name: SirchmunkSearch - - paths: List of paths to search, required - - work_path: Working directory for sirchmunk cache, default './.sirchmunk' - - embedding_model: Embedding model for clustering, default 'text-embedding-3-small' - - cluster_sim_threshold: Threshold for cluster similarity, default 0.85 - - cluster_sim_top_k: Top K clusters to consider, default 3 - - reuse_knowledge: Whether to reuse previous search results, default True - - mode: Search mode (DEEP, FAST, FILENAME_ONLY), default 'FAST' +def effective_localsearch_settings(config: DictConfig) -> Optional[Any]: + """Resolve the active localsearch / sirchmunk settings node. + + Precedence: ``tools.localsearch`` with non-empty ``paths``, else legacy + ``knowledge_search`` with non-empty ``paths``. Returns ``None`` if local + search is not configured. + """ + tools = getattr(config, 'tools', None) + tl = None + if tools is not None: + tl = tools.get('localsearch') if hasattr(tools, 'get') else getattr( + tools, 'localsearch', None) + ks = getattr(config, 'knowledge_search', None) + + if tl is not None and _paths_from_block(tl): + return tl + if ks is not None and _paths_from_block(ks): + return ks + return None + + +class SirchmunkSearch: + """Sirchmunk-based local search (used by :class:`LocalSearchTool`). + + Configure in yaml under ``tools.localsearch`` (recommended), for example:: + + tools: + localsearch: + paths: + - ./src + - ./docs + work_path: ./.sirchmunk + embedding_model: text-embedding-3-small + cluster_sim_threshold: 0.85 + cluster_sim_top_k: 3 + reuse_knowledge: true + mode: FAST + + Legacy: the same keys may be placed under top-level ``knowledge_search``. Args: - config (DictConfig): Configuration object containing sirchmunk settings. + config: Full agent config; sirchmunk options read from the effective + block returned by :func:`effective_localsearch_settings`. """ def __init__(self, config: DictConfig): - super().__init__(config) - self._validate_config(config) + rag_config = effective_localsearch_settings(config) + assert rag_config is not None - # Extract configuration parameters - rag_config = config.get('knowledge_search', {}) - - # Search paths - required paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] @@ -53,11 +84,9 @@ def __init__(self, config: DictConfig): str(Path(p).expanduser().resolve()) for p in paths ] - # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') self.work_path: Path = Path(_work_path).expanduser().resolve() - # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) @@ -66,13 +95,10 @@ def __init__(self, config: DictConfig): self.max_loops = rag_config.get('max_loops', 10) self.max_token_budget = rag_config.get('max_token_budget', 128000) - # LLM configuration for sirchmunk - # First try knowledge_search.llm_api_key, then fall back to main llm config self.llm_api_key = rag_config.get('llm_api_key', None) self.llm_base_url = rag_config.get('llm_base_url', None) self.llm_model_name = rag_config.get('llm_model_name', None) - # Fall back to main llm config if not specified in knowledge_search if (self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None): llm_config = config.get('llm', {}) @@ -87,39 +113,57 @@ def __init__(self, config: DictConfig): if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) - # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( 'embedding_model_cache_dir', None) - # Runtime state self._searcher = None self._initialized = False self._cluster_cache_hit = False self._cluster_cache_hit_time: str | None = None self._last_search_result: List[Dict[str, Any]] | None = None - # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] - # Async queue for streaming logs in real-time self._log_queue: asyncio.Queue | None = None self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): - """Validate configuration parameters.""" - if not hasattr(config, - 'knowledge_search') or config.knowledge_search is None: + block = effective_localsearch_settings(config) + if block is None: raise ValueError( - 'Missing knowledge_search configuration. ' - 'Please add knowledge_search section to your config with at least "paths" specified.' - ) - - rag_config = config.knowledge_search - paths = rag_config.get('paths', []) + 'Missing localsearch configuration. Add ' + '`tools.localsearch` with non-empty `paths` (or legacy ' + '`knowledge_search.paths`).') + paths = _paths_from_block(block) if not paths: raise ValueError( - 'knowledge_search.paths must be specified and non-empty') + 'tools.localsearch.paths (or legacy knowledge_search.paths) ' + 'must be specified and non-empty') + + def resolve_tool_paths(self, + paths: Optional[List[str]]) -> Optional[List[str]]: + """Restrict per-call paths to configured search roots.""" + if not paths: + return None + roots = [Path(p).resolve() for p in self.search_paths] + cleaned: List[str] = [] + for raw in paths: + if raw is None or not str(raw).strip(): + continue + p = Path(str(raw).strip()).expanduser().resolve() + if not p.exists(): + logger.warning( + f'localsearch: path does not exist, skipped: {p}') + continue + allowed = any(p == r or p.is_relative_to(r) for r in roots) + if not allowed: + logger.warning( + f'localsearch: path outside configured search roots, ' + f'skipped: {p}') + continue + cleaned.append(str(p)) + return cleaned or None def _initialize_searcher(self): """Initialize the sirchmunk AgenticSearch instance.""" @@ -131,7 +175,6 @@ def _initialize_searcher(self): from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil - # Create LLM client llm = OpenAIChat( api_key=self.llm_api_key, base_url=self.llm_base_url, @@ -140,8 +183,6 @@ def _initialize_searcher(self): log_callback=self._log_callback_wrapper(), ) - # Create embedding util - # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) embedding_model_id = ( self.embedding_model_id if self.embedding_model_id else None) embedding_cache_dir = ( @@ -150,7 +191,6 @@ def _initialize_searcher(self): embedding = EmbeddingUtil( model_id=embedding_model_id, cache_dir=embedding_cache_dir) - # Create AgenticSearch instance self._searcher = AgenticSearch( llm=llm, embedding=embedding, @@ -191,7 +231,6 @@ def log_callback( ): log_entry = f'[{level.upper()}] {message}' self._search_logs.append(log_entry) - # Stream log in real-time if streaming callback is set if self._streaming_callback: asyncio.create_task(self._streaming_callback(log_entry)) @@ -215,7 +254,6 @@ async def add_documents(self, documents: List[str]) -> bool: 'SirchmunkSearch does not support direct document addition. ' 'Documents should be saved to files within the configured search paths.' ) - # Trigger re-scan of the search paths if self._searcher and hasattr(self._searcher, 'knowledge_base'): try: await self._searcher.knowledge_base.refresh() @@ -274,7 +312,6 @@ async def retrieve(self, max_token_budget = filters.get('max_token_budget', self.max_token_budget) - # Perform search result = await self._searcher.search( query=query, mode=mode, @@ -283,56 +320,102 @@ async def retrieve(self, return_context=True, ) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - # If a similar cluster was found and reused, it's a cache hit self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) - # Get the cluster cache hit time if available if hasattr(result.cluster, 'updated_at'): self._cluster_cache_hit_time = getattr( result.cluster, 'updated_at', None) - # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] - async def query(self, query: str) -> str: - """Query sirchmunk and return a synthesized answer. - - This method performs a search and returns the LLM-synthesized answer - along with search details that can be used for frontend display. + async def query( + self, + query: str, + *, + paths: Optional[List[str]] = None, + mode: Optional[str] = None, + max_depth: Optional[int] = None, + top_k_files: Optional[int] = None, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ) -> str: + """Query sirchmunk and return a synthesized answer (or filename hits). + + Optional arguments are forwarded to ``AgenticSearch.search`` where supported. + ``paths`` must already be restricted to configured search roots (see + :meth:`resolve_tool_paths`). Args: - query (str): The search query. + query: The search query. + paths: Override search roots (subset of configured paths), or None. + mode: ``FAST``, ``DEEP``, or ``FILENAME_ONLY``; None uses config default. + max_depth: Directory depth cap for filesystem search. + top_k_files: Max files for evidence / filename ranking. + include: Glob patterns to include (e.g. ``*.py``). + exclude: Glob patterns to exclude (e.g. ``node_modules``). Returns: - str: The synthesized answer from sirchmunk. + Answer string, or JSON string for ``FILENAME_ONLY`` list results. """ self._initialize_searcher() self._search_logs.clear() try: - mode = self.search_mode - max_loops = self.max_loops - max_token_budget = self.max_token_budget - - # Single search with context so we get both the synthesized answer and - # source units in one call, avoiding a redundant second search. - result = await self._searcher.search( + mode_eff = mode if mode is not None else self.search_mode + if isinstance(mode_eff, str): + mode_eff = mode_eff.strip().upper() + allowed_modes = ('FAST', 'DEEP', 'FILENAME_ONLY') + if mode_eff not in allowed_modes: + return ( + f'Invalid mode {mode_eff!r}; use one of {allowed_modes}.') + + kw: Dict[str, Any] = dict( query=query, - mode=mode, - max_loops=max_loops, - max_token_budget=max_token_budget, + paths=paths, + mode=mode_eff, + max_loops=self.max_loops, + max_token_budget=self.max_token_budget, return_context=True, ) + if max_depth is not None: + kw['max_depth'] = max_depth + if top_k_files is not None: + kw['top_k_files'] = top_k_files + if include is not None: + kw['include'] = include + if exclude is not None: + kw['exclude'] = exclude + + result = await self._searcher.search(**kw) + + if isinstance(result, list): + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + self._last_search_result = [] + for item in result[:20]: + if isinstance(item, dict): + src = ( + item.get('path') or item.get('file_path') + or item.get('file') or '') + self._last_search_result.append({ + 'text': + json.dumps(item, ensure_ascii=False), + 'score': + 1.0, + 'metadata': { + 'source': str(src), + 'type': 'filename_match', + }, + }) + return json.dumps(result, ensure_ascii=False, indent=2) - # Check if cluster cache was hit self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: @@ -342,19 +425,16 @@ async def query(self, query: str) -> str: self._cluster_cache_hit_time = getattr( result.cluster, 'updated_at', None) - # Store parsed context for frontend display self._last_search_result = self._parse_search_result( result, score_threshold=0.7, limit=5) - # Extract the synthesized answer from the context result - if hasattr(result, 'answer'): + if hasattr(result, 'answer') and getattr(result, 'answer', + None) is not None: return result.answer - # If result is already a plain string (some modes return str directly) if isinstance(result, str): return result - # Fallback: convert to string return str(result) except Exception as e: @@ -375,14 +455,11 @@ def _parse_search_result(self, result: Any, score_threshold: float, """ results = [] - # Handle SearchContext format (returned when return_context=True) if hasattr(result, 'cluster') and result.cluster is not None: cluster = result.cluster for unit in cluster.evidences: - # Extract score from snippets if available score = getattr(cluster, 'confidence', 1.0) if score >= score_threshold: - # Extract text from snippets text_parts = [] source = str(getattr(unit, 'file_or_url', 'unknown')) for snippet in getattr(unit, 'snippets', []): @@ -406,7 +483,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) @@ -423,7 +499,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, }, }) - # Handle list format elif isinstance(result, list): for item in result: if isinstance(item, dict): @@ -438,7 +513,6 @@ def _parse_search_result(self, result: Any, score_threshold: float, item.get('metadata', {}), }) - # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: @@ -451,10 +525,13 @@ def _parse_search_result(self, result: Any, score_threshold: float, result.get('metadata', {}), }) - # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) return results[:limit] + def get_last_retrieved_chunks(self) -> List[Dict[str, Any]]: + """Parsed evidence chunks from the last `query` or `retrieve` call.""" + return list(self._last_search_result or []) + def get_search_logs(self) -> List[str]: """Get the captured search logs. diff --git a/ms_agent/tools/search/tavily/__init__.py b/ms_agent/tools/search/tavily/__init__.py new file mode 100644 index 000000000..ef3e913f2 --- /dev/null +++ b/ms_agent/tools/search/tavily/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from ms_agent.tools.search.tavily.search import TavilySearch + +__all__ = ['TavilySearch'] diff --git a/ms_agent/tools/search/tavily/fetcher.py b/ms_agent/tools/search/tavily/fetcher.py new file mode 100644 index 000000000..d4ed4927a --- /dev/null +++ b/ms_agent/tools/search/tavily/fetcher.py @@ -0,0 +1,94 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tavily Extract API as ContentFetcher (replaces Jina for fetch_page / URL fetch).""" +import os +import time +from typing import Any, Dict, Optional, Tuple + +from ms_agent.tools.search.tavily.http import post_json +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +TAVILY_EXTRACT_URL = 'https://api.tavily.com/extract' + + +class TavilyExtractFetcher: + """ + Fetch page text via Tavily POST /extract. + + Uses ``extract_depth=advanced`` and ``format=markdown`` by default for + richest structured text (tables, etc., when available). + """ + + def __init__( + self, + api_key: Optional[str] = None, + *, + extract_depth: str = 'advanced', + format: str = 'markdown', + timeout: float = 45.0, + chunks_per_source: int = 3, + include_images: bool = False, + include_favicon: bool = False, + include_usage: bool = False, + ): + key = api_key or os.getenv('TAVILY_API_KEY') + if not key: + raise ValueError( + 'TAVILY_API_KEY required for tavily_extract fetcher') + self._api_key = key + self._extract_depth = extract_depth + self._format = format + self._timeout = max(1.0, min(60.0, float(timeout))) + self._chunks_per_source = max(1, min(5, int(chunks_per_source))) + self._include_images = include_images + self._include_favicon = include_favicon + self._include_usage = include_usage + + def fetch(self, + url: str, + query: Optional[str] = None) -> Tuple[str, Dict[str, Any]]: + """ + Extract one URL. Optional ``query`` enables chunk reranking (more relevant raw_content). + """ + body: Dict[str, Any] = { + 'api_key': self._api_key, + 'urls': [url], + 'extract_depth': self._extract_depth, + 'format': self._format, + 'timeout': self._timeout, + 'include_images': self._include_images, + 'include_favicon': self._include_favicon, + 'include_usage': self._include_usage, + } + if query: + body['query'] = query + body['chunks_per_source'] = self._chunks_per_source + + try: + data = post_json( + TAVILY_EXTRACT_URL, body, timeout=self._timeout + 30.0) + except Exception as e: + logger.warning(f'Tavily extract failed for {url[:80]}: {e}') + return '', { + 'fetcher': 'tavily_extract', + 'error': str(e), + 'fetched_at': time.strftime('%Y-%m-%dT%H:%M:%S'), + } + + results = data.get('results') or [] + text = '' + if results: + text = (results[0].get('raw_content') or '').strip() + meta: Dict[str, Any] = { + 'fetcher': 'tavily_extract', + 'fetched_at': time.strftime('%Y-%m-%dT%H:%M:%S'), + 'tavily_response_time': data.get('response_time'), + 'tavily_usage': data.get('usage'), + 'tavily_request_id': data.get('request_id'), + } + failed = data.get('failed_results') or [] + if failed and not text: + err = failed[0].get('error', 'unknown') + meta['error'] = err + return text, meta diff --git a/ms_agent/tools/search/tavily/http.py b/ms_agent/tools/search/tavily/http.py new file mode 100644 index 000000000..e2ad2cb85 --- /dev/null +++ b/ms_agent/tools/search/tavily/http.py @@ -0,0 +1,49 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Minimal HTTP JSON client for Tavily REST API (stdlib only).""" +import json +from typing import Any, Dict +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + + +def post_json( + url: str, + body: Dict[str, Any], + *, + timeout: float = 120.0, +) -> Dict[str, Any]: + """ + POST JSON and parse JSON response. + + Raises: + RuntimeError: on HTTP errors or invalid JSON (includes Tavily error body). + """ + data = json.dumps(body, ensure_ascii=False).encode('utf-8') + req = Request( + url, + data=data, + method='POST', + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + ) + try: + with urlopen(req, timeout=timeout) as resp: + raw = resp.read().decode('utf-8', errors='replace') + if not raw.strip(): + return {} + return json.loads(raw) + except HTTPError as e: + err_body = '' + try: + err_body = e.read().decode('utf-8', errors='replace') + except Exception: + pass + try: + detail = json.loads(err_body) if err_body else {} + except json.JSONDecodeError: + detail = {'raw': err_body} + raise RuntimeError(f'Tavily HTTP {e.code}: {detail}') from e + except URLError as e: + raise RuntimeError(f'Tavily network error: {e}') from e diff --git a/ms_agent/tools/search/tavily/schema.py b/ms_agent/tools/search/tavily/schema.py new file mode 100644 index 000000000..112a41839 --- /dev/null +++ b/ms_agent/tools/search/tavily/schema.py @@ -0,0 +1,122 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class TavilySearchRequest: + """Tavily POST /search body. See https://docs.tavily.com/documentation/api-reference/endpoint/search""" + + query: str + max_results: int = 10 + search_depth: str = 'advanced' + chunks_per_source: int = 3 + topic: str = 'general' + time_range: Optional[str] = None + start_date: Optional[str] = None + end_date: Optional[str] = None + # include_answer: false | true | basic | advanced + include_answer: Any = 'advanced' + # include_raw_content: false | true | markdown | text — use markdown for richest text + include_raw_content: Any = 'markdown' + include_images: bool = False + include_image_descriptions: bool = False + include_favicon: bool = False + include_domains: List[str] = field(default_factory=list) + exclude_domains: List[str] = field(default_factory=list) + country: Optional[str] = None + auto_parameters: bool = False + exact_match: bool = False + include_usage: bool = False + safe_search: bool = False + + def to_api_body(self, api_key: str) -> Dict[str, Any]: + n = max(0, min(20, int(self.max_results))) + body: Dict[str, Any] = { + 'api_key': api_key, + 'query': self.query, + 'max_results': n, + 'search_depth': self.search_depth, + 'topic': self.topic, + 'include_answer': self.include_answer, + 'include_raw_content': self.include_raw_content, + 'include_images': self.include_images, + 'include_image_descriptions': self.include_image_descriptions, + 'include_favicon': self.include_favicon, + 'auto_parameters': self.auto_parameters, + 'exact_match': self.exact_match, + 'include_usage': self.include_usage, + 'safe_search': self.safe_search, + } + # chunks_per_source only meaningful for advanced (per Tavily docs) + if self.search_depth == 'advanced': + body['chunks_per_source'] = max( + 1, min(3, int(self.chunks_per_source))) + if self.time_range: + body['time_range'] = self.time_range + if self.start_date: + body['start_date'] = self.start_date + if self.end_date: + body['end_date'] = self.end_date + if self.include_domains: + body['include_domains'] = list(self.include_domains)[:300] + if self.exclude_domains: + body['exclude_domains'] = list(self.exclude_domains)[:150] + if self.country: + body['country'] = self.country + return body + + +@dataclass +class TavilySearchResult: + """Parsed Tavily /search JSON.""" + + query: str + arguments: Dict[str, Any] + response: Dict[str, Any] + + def to_list(self) -> List[Dict[str, Any]]: + """Normalize to WebSearchTool pipeline dicts (prefill content when raw_content present).""" + if not self.response: + return [] + rows: List[Dict[str, Any]] = [] + for r in self.response.get('results') or []: + url = r.get('url') or '' + title = r.get('title') or '' + snippet = (r.get('content') or '').strip() + raw = (r.get('raw_content') or '').strip() + # Prefer full page text for downstream summarization; fallback to snippets + body = raw if raw else snippet + rows.append({ + 'url': url, + 'id': url, + 'title': title, + 'highlights': None, + 'highlight_scores': None, + 'summary': snippet, + 'markdown': raw if raw else None, + # Pipeline uses these keys: + 'content': body, + 'fetch_success': bool(raw), + 'score': r.get('score'), + 'tavily_images': r.get('images') or [], + 'favicon': r.get('favicon'), + }) + return rows + + def extra_response_fields(self) -> Dict[str, Any]: + """Top-level fields to merge into web_search JSON output.""" + if not self.response: + return {} + out: Dict[str, Any] = {} + if self.response.get('answer'): + out['tavily_answer'] = self.response['answer'] + if self.response.get('images'): + out['tavily_images'] = self.response['images'] + if self.response.get('response_time') is not None: + out['tavily_response_time'] = self.response['response_time'] + if self.response.get('usage'): + out['tavily_usage'] = self.response['usage'] + if self.response.get('request_id'): + out['tavily_request_id'] = self.response['request_id'] + return out diff --git a/ms_agent/tools/search/tavily/search.py b/ms_agent/tools/search/tavily/search.py new file mode 100644 index 000000000..38c3c69cb --- /dev/null +++ b/ms_agent/tools/search/tavily/search.py @@ -0,0 +1,234 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import os +from typing import TYPE_CHECKING, Any, Dict, Optional + +from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType +from ms_agent.tools.search.tavily.http import post_json +from ms_agent.tools.search.tavily.schema import (TavilySearchRequest, + TavilySearchResult) +from ms_agent.utils.logger import get_logger + +if TYPE_CHECKING: + from ms_agent.llm.utils import Tool + +logger = get_logger() + +TAVILY_SEARCH_URL = 'https://api.tavily.com/search' + + +class TavilySearch(SearchEngine): + """ + Tavily Search API — optimized for LLM agents. + + Defaults favor maximum usable text: ``search_depth=advanced``, + ``include_raw_content=markdown``, ``include_answer=advanced``, + ``chunks_per_source=3`` (capped by Tavily). + """ + + engine_type = SearchEngineType.TAVILY + + def __init__( + self, + api_key: Optional[str] = None, + request_timeout: float = 120.0, + ): + key = api_key or os.getenv('TAVILY_API_KEY') + if not key: + raise ValueError( + 'TAVILY_API_KEY must be set in environment or web_search.tavily_api_key' + ) + self._api_key = key + self._request_timeout = float(request_timeout) + + def search(self, + search_request: TavilySearchRequest) -> TavilySearchResult: + body = search_request.to_api_body(self._api_key) + try: + data = post_json( + TAVILY_SEARCH_URL, body, timeout=self._request_timeout) + except Exception as e: + raise RuntimeError(f'Tavily search failed: {e}') from e + safe_args = {k: v for k, v in body.items() if k != 'api_key'} + safe_args['api_key'] = '' + return TavilySearchResult( + query=search_request.query, + arguments=safe_args, + response=data or {}, + ) + + @classmethod + def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': + from ms_agent.llm.utils import Tool + return Tool( + tool_name=cls.get_tool_name(), + server_name=server_name, + description=( + 'Search the web using Tavily (built for AI agents). ' + 'Returns ranked results with optional full-page markdown via ' + '`include_raw_content`. Use `search_depth` advanced for best ' + 'relevance and richer `content` chunks (higher API credit use).' + ), + parameters={ + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'Search query.', + }, + 'num_results': { + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 20, + 'description': + 'Max results (maps to Tavily max_results). Default 10.', + }, + 'search_depth': { + 'type': + 'string', + 'enum': ['advanced', 'basic', 'fast', 'ultra-fast'], + 'description': + ('advanced: best quality, 2 credits; ' + 'basic/fast/ultra-fast: 1 credit (see Tavily docs).'), + }, + 'topic': { + 'type': + 'string', + 'enum': ['general', 'news', 'finance'], + 'description': + 'Search category (`news` / `finance` for focused verticals).', + }, + 'time_range': { + 'type': + 'string', + 'description': + ('Filter by recency: day, week, month, year or d,w,m,y.' + ), + }, + 'start_date': { + 'type': 'string', + 'description': 'Results after YYYY-MM-DD.', + }, + 'end_date': { + 'type': 'string', + 'description': 'Results before YYYY-MM-DD.', + }, + 'include_answer': { + 'type': + 'string', + 'enum': ['false', 'true', 'basic', 'advanced'], + 'description': + ('LLM answer: true/basic for short, advanced for detailed. ' + 'Use false to skip.'), + }, + 'include_raw_content': { + 'type': + 'string', + 'enum': ['false', 'true', 'markdown', 'text'], + 'description': + ('full page text: markdown (recommended) or text; ' + 'false to skip raw content.'), + }, + 'chunks_per_source': { + 'type': + 'integer', + 'minimum': + 1, + 'maximum': + 3, + 'description': + ('Relevant chunks per URL when search_depth=advanced. ' + 'Each chunk up to ~500 chars in `content` field.'), + }, + 'include_domains': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': 'Only include these domains (max 300).', + }, + 'exclude_domains': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': 'Exclude domains (max 150).', + }, + 'country': { + 'type': + 'string', + 'description': + ('Boost results from country (e.g. united states). ' + 'See Tavily docs for enum.'), + }, + 'exact_match': { + 'type': + 'boolean', + 'description': + 'Only results with exact quoted phrases in query.', + }, + }, + 'required': ['query'], + }, + ) + + @classmethod + def build_request_from_args(cls, **kwargs: Any) -> TavilySearchRequest: + """Build from merged tool args + YAML defaults (see WebSearchTool).""" + + def _boolish(name: str, default: Any) -> Any: + if name not in kwargs: + return default + v = kwargs[name] + if isinstance(v, str) and v.lower() in ('false', 'true'): + return v.lower() == 'true' + return v + + num = kwargs.get('num_results', kwargs.get('max_results', 10)) + try: + num = int(num) + except (TypeError, ValueError): + num = 10 + + try: + cps = int(kwargs.get('chunks_per_source', 3)) + except (TypeError, ValueError): + cps = 3 + + inc_ans = kwargs.get('include_answer', 'advanced') + if isinstance(inc_ans, str) and inc_ans.lower() == 'false': + inc_ans = False + elif isinstance(inc_ans, str) and inc_ans.lower() == 'true': + inc_ans = True + + inc_raw = kwargs.get('include_raw_content', 'markdown') + if isinstance(inc_raw, str) and inc_raw.lower() == 'false': + inc_raw = False + elif isinstance(inc_raw, str) and inc_raw.lower() == 'true': + inc_raw = 'markdown' + + return TavilySearchRequest( + query=kwargs['query'], + max_results=num, + search_depth=str(kwargs.get('search_depth', 'advanced')), + chunks_per_source=cps, + topic=str(kwargs.get('topic', 'general')), + time_range=kwargs.get('time_range'), + start_date=kwargs.get('start_date'), + end_date=kwargs.get('end_date'), + include_answer=inc_ans, + include_raw_content=inc_raw, + include_images=bool(_boolish('include_images', False)), + include_image_descriptions=bool( + _boolish('include_image_descriptions', False)), + include_favicon=bool(_boolish('include_favicon', False)), + include_domains=list(kwargs.get('include_domains') or []), + exclude_domains=list(kwargs.get('exclude_domains') or []), + country=kwargs.get('country'), + auto_parameters=bool(_boolish('auto_parameters', False)), + exact_match=bool(_boolish('exact_match', False)), + include_usage=bool(_boolish('include_usage', False)), + safe_search=bool(_boolish('safe_search', False)), + ) diff --git a/ms_agent/tools/search/web_search_spill.py b/ms_agent/tools/search/web_search_spill.py new file mode 100644 index 000000000..11a29bf16 --- /dev/null +++ b/ms_agent/tools/search/web_search_spill.py @@ -0,0 +1,284 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Offload oversized web_search tool payloads to disk so LLM context stays bounded. + +When +---- +Estimated inline character volume (per-result ``content``, ``summary``, ``abstract``, +and ``chunks[*].content``) exceeds ``spill_max_inline_chars``. + +Where / lifecycle +----------------- +``{output_dir}/{spill_subdir}/{run_key}/`` — same lifecycle as the task +``output_dir`` (delete the run workdir to reclaim space; no automatic pruning). + +Naming +------ +``run_key = {UTC compact}_{call_id_or_random}`` — unique, sortable, filesystem-safe. + +Files +----- +* ``manifest.json`` — index (query, engine, URLs, paths, sizes, previews). +* ``bodies/{i:03d}.md`` — UTF-8 full text for spilled rows (sections for content / + abstract / chunks). + +Return payload +-------------- +JSON gains ``spill`` with ``digest`` (instructions + quick index) and paths relative +to ``output_dir`` so ``read_file`` can open them. +""" +from __future__ import annotations + +import copy +import json +import os +import re +import time +import uuid +from typing import Any, Dict, List, Tuple + + +def _item_inline_chars(it: Dict[str, Any]) -> int: + n = 0 + for k in ('content', 'summary', 'abstract'): + v = it.get(k) + if isinstance(v, str): + n += len(v) + chunks = it.get('chunks') + if isinstance(chunks, list): + for c in chunks: + if isinstance(c, dict): + s = c.get('content') + if isinstance(s, str): + n += len(s) + return n + + +def _total_inline_chars(items: List[Dict[str, Any]]) -> int: + return sum(_item_inline_chars(x) for x in items) + + +def _preview(text: str, max_chars: int) -> str: + t = (text or '').strip() + if len(t) <= max_chars: + return t + return t[:max_chars].rstrip() + '\n…' + + +def _safe_run_key(call_id: str) -> str: + ts = time.strftime('%Y%m%dT%H%M%SZ', time.gmtime()) + tail = (call_id or '').strip() + tail = re.sub(r'[^a-zA-Z0-9._-]+', '_', tail)[:24] + if not tail: + tail = uuid.uuid4().hex[:12] + return f'{ts}_{tail}' + + +def _build_spill_markdown(item: Dict[str, Any]) -> str: + """Assemble full text for one result row.""" + lines: List[str] = [] + url = item.get('url', '') + title = item.get('title', '') + lines.append(f'# {title or "(no title)"}\n') + lines.append(f'**URL:** {url}\n') + summary = item.get('summary') + if isinstance(summary, str) and summary.strip(): + lines.append('\n## Summary (search snippet)\n\n') + lines.append(summary) + lines.append('\n') + content = item.get('content') + if isinstance(content, str) and content.strip(): + lines.append('\n## Content\n\n') + lines.append(content) + lines.append('\n') + abstract = item.get('abstract') + if isinstance(abstract, str) and abstract.strip(): + lines.append('\n## Abstract\n\n') + lines.append(abstract) + lines.append('\n') + chunks = item.get('chunks') + if isinstance(chunks, list) and chunks: + lines.append('\n## Chunks\n\n') + for c in chunks: + if not isinstance(c, dict): + continue + cid = c.get('chunk_id', '') + body = c.get('content', '') + lines.append(f'### chunk `{cid}`\n\n') + if isinstance(body, str): + lines.append(body) + lines.append('\n') + return ''.join(lines) + + +def _shrink_item_after_spill(item: Dict[str, Any], + spill_preview_chars: int) -> Dict[str, Any]: + """Replace heavy fields with short previews + pointers.""" + out = dict(item) + note = ( + 'Full text spilled to disk; see content_path / manifest_path in parent ' + 'JSON spill block. Use read_file on content_path for this row.') + sm = out.get('summary') + if isinstance(sm, str) and sm.strip(): + out['summary'] = _preview(sm, spill_preview_chars) + out.setdefault('content_note', note) + main = (out.get('content') or '') + if isinstance(main, str) and main.strip(): + out['content'] = _preview(main, spill_preview_chars) + out['content_note'] = note + ab = out.get('abstract') + if isinstance(ab, str) and ab.strip(): + out['abstract'] = _preview(ab, min(800, spill_preview_chars)) + ch = out.get('chunks') + if isinstance(ch, list) and ch: + out['chunks'] = [{ + 'chunk_id': + c.get('chunk_id', ''), + 'content': + _preview(str(c.get('content', '')), min(400, spill_preview_chars)), + } for c in ch if isinstance(c, dict)] + out['chunks_note'] = 'Full chunk bodies are in the spilled markdown file.' + return out + + +def maybe_spill_web_search_payload( + *, + output_dir: str, + spill_subdir: str, + spill_max_inline_chars: int, + spill_preview_chars: int, + query: str, + engine: str, + results: List[Dict[str, Any]], + call_id: str, +) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """ + If total inline chars exceed threshold, spill largest rows first until under. + + Returns: + (possibly_mutated_results, spill_meta_dict) + spill_meta_dict is empty if no spill occurred. + """ + if not output_dir or not results: + return results, {} + + work = copy.deepcopy(results) + total = _total_inline_chars(work) + if total <= spill_max_inline_chars: + return results, {} + + run_key = _safe_run_key(call_id) + root = os.path.abspath(os.path.join(output_dir, spill_subdir, run_key)) + bodies_dir = os.path.join(root, 'bodies') + os.makedirs(bodies_dir, exist_ok=True) + + spilled_indices: List[int] = [] + manifest_rows: List[Dict[str, Any]] = [] + + def order_by_size() -> List[int]: + sizes = [(i, _item_inline_chars(work[i])) for i in range(len(work))] + sizes.sort(key=lambda x: x[1], reverse=True) + return [i for i, sz in sizes if sz > 0] + + while _total_inline_chars(work) > spill_max_inline_chars: + order = order_by_size() + if not order: + break + idx = order[0] + item = work[idx] + if _item_inline_chars(item) == 0: + break + full_md = _build_spill_markdown(item) + rel_body = os.path.join(spill_subdir, run_key, 'bodies', + f'{idx:03d}.md').replace('\\', '/') + abs_body = os.path.normpath( + os.path.join(output_dir, rel_body.replace('/', os.sep))) + os.makedirs(os.path.dirname(abs_body), exist_ok=True) + header = ( + f'\n') + with open(abs_body, 'w', encoding='utf-8') as bf: + bf.write(header + full_md) + + spilled_indices.append(idx) + before_chars = _item_inline_chars(item) + work[idx] = _shrink_item_after_spill(item, spill_preview_chars) + work[idx]['content_spilled'] = True + work[idx]['content_path'] = rel_body + work[idx]['content_chars_spilled'] = before_chars + + preview_src = (item.get('content') or item.get('summary') + or item.get('abstract') or '')[:4000] + manifest_rows.append({ + 'index': + idx, + 'url': + item.get('url', ''), + 'title': + item.get('title', ''), + 'body_file': + f'bodies/{idx:03d}.md', + 'content_path': + rel_body, + 'chars_spilled': + before_chars, + 'preview': + _preview(preview_src, min(500, spill_preview_chars)), + }) + + manifest: Dict[str, Any] = { + 'version': + 1, + 'created_at_utc': + time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()), + 'query': + query, + 'engine': + engine, + 'run_key': + run_key, + 'lifecycle': + ('Ephemeral: lives under this task output_dir; delete the task directory ' + 'to remove. ms-agent does not auto-prune.'), + 'inline_chars_before': + total, + 'inline_chars_after': + _total_inline_chars(work), + 'spill_threshold_chars': + spill_max_inline_chars, + 'spilled_row_indices': + spilled_indices, + 'rows': + manifest_rows, + } + rel_manifest = os.path.join(spill_subdir, run_key, + 'manifest.json').replace('\\', '/') + abs_manifest = os.path.normpath( + os.path.join(output_dir, rel_manifest.replace('/', os.sep))) + with open(abs_manifest, 'w', encoding='utf-8') as mf: + json.dump(manifest, mf, ensure_ascii=False, indent=2) + + lines = [ + 'Large web_search payload was written to disk under this task output_dir.', + f'- **Manifest (map of rows → files, sizes)**: `{rel_manifest}`', + f'- **Bodies**: `{spill_subdir}/{run_key}/bodies/`', + 'Read **manifest.json** first, then **read_file** on specific ' + '`bodies/NNN.md` files as needed.', + '', + '**Quick index**', + ] + for row in manifest_rows: + lines.append( + f'{row["index"]}. {row.get("title") or "(no title)"} — ' + f'`{row["content_path"]}` ({row.get("chars_spilled", 0)} chars)') + digest = '\n'.join(lines) + + spill_meta = { + 'spilled': True, + 'run_key': run_key, + 'artifact_dir': f'{spill_subdir}/{run_key}'.replace('\\', '/'), + 'manifest_path': rel_manifest, + 'digest': digest, + 'inline_chars_before_spill': total, + 'inline_chars_after_spill': _total_inline_chars(work), + } + return work, spill_meta diff --git a/ms_agent/tools/search/websearch_tool.py b/ms_agent/tools/search/websearch_tool.py index 7bbecfe89..167bbe3bb 100644 --- a/ms_agent/tools/search/websearch_tool.py +++ b/ms_agent/tools/search/websearch_tool.py @@ -10,11 +10,14 @@ from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase -from ms_agent.tools.jina_reader import JinaReaderConfig, fetch_single_text +from ms_agent.tools.jina_reader import (JinaReaderConfig, + fetch_single_text_with_meta) from ms_agent.tools.search.content_optimizer import (ContentOptimizer, ContentOptimizerConfig, SearchResultReranker) from ms_agent.tools.search.search_base import ENGINE_TOOL_NAMES, SearchEngine +from ms_agent.tools.search.web_search_spill import \ + maybe_spill_web_search_payload from ms_agent.utils.logger import get_logger from ms_agent.utils.thread_util import DaemonThreadPoolExecutor @@ -23,6 +26,29 @@ MAX_FETCH_CHARS = int(os.getenv('MAX_FETCH_CHARS', 100000)) +def default_per_url_fetch_timeout_s( + fetch_timeout: float, + fetch_retries: int, + direct_fetch_timeout: float, + playwright_timeout_ms: int, +) -> float: + """ + Default asyncio cap for a single URL fetch (Jina → optional direct → optional PW). + + Must exceed a worst-case Jina-only path (``fetch_timeout`` × (retries+1) attempts + plus backoff headroom) before tier-2/3, otherwise logs show all TIMEOUT and no + ``fetch_done``. Clamped to keep runaway tabs bounded. + """ + ft = max(5.0, float(fetch_timeout)) + retries = max(0, int(fetch_retries)) + # Up to (retries+1) attempts each up to ``ft``; 1.35 leaves slack for urllib backoff. + jina_budget = ft * float(retries + 1) * 1.35 + tail = max(10.0, float(direct_fetch_timeout)) + ( + float(playwright_timeout_ms) / 1000.0) + 30.0 + raw = jina_budget + tail + return max(210.0, min(720.0, raw)) + + def _json_dumps(data: Any) -> str: import json return json.dumps(data, ensure_ascii=False, indent=2) @@ -148,10 +174,11 @@ def fetch( url: str, max_chars: Optional[int] = MAX_FETCH_CHARS ) -> Tuple[str, Dict[str, Any]]: - content = fetch_single_text(url, self.config) + content, source_meta = fetch_single_text_with_meta(url, self.config) metadata: Dict[str, Any] = { 'fetcher': 'jina_reader', 'fetched_at': time.strftime('%Y-%m-%dT%H:%M:%S'), + **source_meta, } if max_chars: @@ -171,10 +198,37 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', """Factory function to get content fetcher by type.""" if fetcher_type == 'jina_reader': config = JinaReaderConfig( - timeout=kwargs.get('timeout', 30.0), + timeout=kwargs.get('timeout', 45.0), retries=kwargs.get('retries', 3), + direct_fetch_fallback=bool( + kwargs.get('direct_fetch_fallback', True)), + direct_fetch_timeout=float( + kwargs.get('direct_fetch_timeout', 15.0)), + playwright_fetch_fallback=bool( + kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int( + kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int( + kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_settle_ms=int(kwargs.get('playwright_settle_ms', 350)), ) return JinaContentFetcher(config) + if fetcher_type == 'tavily_extract': + from ms_agent.tools.search.tavily.fetcher import TavilyExtractFetcher + return TavilyExtractFetcher( + api_key=kwargs.get('tavily_api_key'), + extract_depth=str(kwargs.get('tavily_extract_depth', 'advanced')), + format=str(kwargs.get('tavily_extract_format', 'markdown')), + timeout=float(kwargs.get('timeout', 45.0)), + chunks_per_source=int( + kwargs.get('tavily_extract_chunks_per_source', 3)), + include_images=bool( + kwargs.get('tavily_extract_include_images', False)), + include_favicon=bool( + kwargs.get('tavily_extract_include_favicon', False)), + include_usage=bool( + kwargs.get('tavily_extract_include_usage', False)), + ) # Future: add more fetchers # elif fetcher_type == 'docling': # return DoclingContentFetcher(**kwargs) @@ -182,7 +236,23 @@ def get_content_fetcher(fetcher_type: str = 'jina_reader', logger.warning( f"Unknown fetcher type '{fetcher_type}', falling back to jina_reader" ) - return JinaContentFetcher() + return JinaContentFetcher( + JinaReaderConfig( + timeout=kwargs.get('timeout', 45.0), + retries=kwargs.get('retries', 3), + direct_fetch_fallback=bool( + kwargs.get('direct_fetch_fallback', True)), + direct_fetch_timeout=float( + kwargs.get('direct_fetch_timeout', 15.0)), + playwright_fetch_fallback=bool( + kwargs.get('playwright_fetch_fallback', True)), + playwright_retry_min_chars=int( + kwargs.get('playwright_retry_min_chars', 400) or 400), + playwright_timeout_ms=int( + kwargs.get('playwright_timeout_ms', 30_000) or 30_000), + playwright_settle_ms=int( + kwargs.get('playwright_settle_ms', 350)), + )) def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: @@ -206,6 +276,9 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch return ArxivSearch + elif engine_type == 'tavily': + from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch else: logger.warning( f"Unknown search engine '{engine_type}', falling back to arxiv") @@ -221,17 +294,23 @@ def get_search_engine(engine_type: str, Args: engine_type: One of 'exa', 'serpapi', 'arxiv' - api_key: API key for the search engine (if required) - **kwargs: Additional arguments passed to engine constructor + api_key: API key for the search engine (if required). + For Exa, this can be a comma-separated string of multiple keys + to enable automatic key rotation on credit exhaustion. + **kwargs: Additional arguments passed to engine constructor. + For Exa, ``api_keys`` (list or comma-separated str) is also + accepted to supply a key pool explicitly. """ engine_type = engine_type.lower().strip() if engine_type == 'exa': from ms_agent.tools.search.exa import ExaSearch - return ExaSearch(api_key=api_key or os.getenv('EXA_API_KEY')) + return ExaSearch( + api_key=api_key or os.getenv('EXA_API_KEY'), + api_keys=kwargs.get('api_keys') or os.getenv('EXA_API_KEYS'), + ) elif engine_type in ('serpapi', 'serp', 'google', 'bing', 'baidu'): from ms_agent.tools.search.serpapi import SerpApiSearch - # Allow shorthand engine_type aliases to imply provider default_provider = ('google' if engine_type in ('serpapi', 'serp') else engine_type) return SerpApiSearch( @@ -241,6 +320,12 @@ def get_search_engine(engine_type: str, elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch return ArxivSearch() + elif engine_type == 'tavily': + from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch( + api_key=api_key or os.getenv('TAVILY_API_KEY'), + request_timeout=float(kwargs.get('request_timeout', 120.0)), + ) else: logger.warning( f"Unknown search engine '{engine_type}', falling back to arxiv") @@ -265,7 +350,7 @@ def build_search_request(engine_type: str, class WebSearchTool(ToolBase): """ Unified web search tool for agents. It can search the web and fetch page content. - - Search via multiple engines (Exa, SerpAPI, Arxiv) + - Search via multiple engines (Exa, SerpAPI, Arxiv, Tavily) - Dynamic tool definitions based on configured engines - Auto-fetch and parse page content - Configurable content fetcher (jina_reader, docling, etc.) @@ -291,12 +376,14 @@ class WebSearchTool(ToolBase): exa_api_key: $EXA_API_KEY serpapi_api_key: $SERPAPI_API_KEY fetch_content: true + # Optional: asyncio deadline per URL (omit = auto from fetch_timeout/retries). + # per_url_fetch_timeout: 0 """ SERVER_NAME = 'web_search' # Registry of supported search engines - SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv') + SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv', 'tavily') # Process-wide (class-level) usage tracking for summarization calls. # This is intentionally separate from LLMAgent usage totals. @@ -394,18 +481,46 @@ def __init__(self, config, **kwargs): 'No valid engines configured, falling back to arxiv') self._engine_types = ['arxiv'] - # API keys for each engine + # API keys for each engine. + # Exa supports a key pool: exa_api_keys (list/comma-separated) takes + # priority over the single-key fields. The value is forwarded as-is to + # ExaSearch which handles parsing into a list internally. self._api_keys = { 'exa': ( - getattr(tool_cfg, 'exa_api_key', None) + getattr(tool_cfg, 'exa_api_keys', None) + or getattr(tool_cfg, 'exa_api_key', None) or getattr(tool_cfg, 'api_key', None) # backward compat - or os.getenv('EXA_API_KEY')) - if tool_cfg else os.getenv('EXA_API_KEY'), + or os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')) + if tool_cfg else + (os.getenv('EXA_API_KEYS') or os.getenv('EXA_API_KEY')), 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) if tool_cfg else os.getenv('SERPAPI_API_KEY'), + 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) + or os.getenv('TAVILY_API_KEY')) + if tool_cfg else os.getenv('TAVILY_API_KEY'), } + # Tavily search defaults from optional `tavily:` sub-block in YAML + self._tavily_defaults: Dict[str, Any] = {} + if tool_cfg is not None: + tv = getattr(tool_cfg, 'tavily', None) + if tv is not None: + try: + from omegaconf import OmegaConf + if OmegaConf.is_config(tv): + self._tavily_defaults = dict( + OmegaConf.to_container(tv, resolve=True)) + elif isinstance(tv, dict): + self._tavily_defaults = dict(tv) + except Exception: + if isinstance(tv, dict): + self._tavily_defaults = dict(tv) + + self._tavily_request_timeout = float( + getattr(tool_cfg, 'tavily_request_timeout', 120.0) + or 120.0) if tool_cfg else 120.0 + # SerpApi provider (google, bing, baidu) self._serpapi_provider = getattr(tool_cfg, 'serpapi_provider', 'google') if tool_cfg else 'google' @@ -418,9 +533,33 @@ def __init__(self, config, **kwargs): self._fetcher_type = getattr( tool_cfg, 'fetcher', 'jina_reader') if tool_cfg else 'jina_reader' self._fetch_timeout = float( - getattr(tool_cfg, 'fetch_timeout', 30) or 30) if tool_cfg else 30.0 + getattr(tool_cfg, 'fetch_timeout', 45) or 45) if tool_cfg else 45.0 self._fetch_retries = int(getattr(tool_cfg, 'fetch_retries', 3) or 3) if tool_cfg else 3 + self._jina_direct_fetch_fallback = bool( + getattr(tool_cfg, 'jina_direct_fetch_fallback', + True)) if tool_cfg else True + if tool_cfg is not None and hasattr(tool_cfg, + 'jina_direct_fetch_timeout'): + self._jina_direct_fetch_timeout = float( + tool_cfg.jina_direct_fetch_timeout) + else: + self._jina_direct_fetch_timeout = 15.0 + self._jina_playwright_fetch_fallback = bool( + getattr(tool_cfg, 'jina_playwright_fetch_fallback', + True)) if tool_cfg else True + self._jina_playwright_retry_min_chars = int( + getattr(tool_cfg, 'jina_playwright_retry_min_chars', 400) + or 400) if tool_cfg else 400 + self._jina_playwright_timeout_ms = int( + getattr(tool_cfg, 'jina_playwright_timeout_ms', 30000) + or 30000) if tool_cfg else 30000 + if tool_cfg is not None and hasattr(tool_cfg, + 'jina_playwright_settle_ms'): + self._jina_playwright_settle_ms = int( + tool_cfg.jina_playwright_settle_ms) + else: + self._jina_playwright_settle_ms = 350 self._fetch_content_default = bool( getattr(tool_cfg, 'fetch_content', True)) if tool_cfg else True @@ -437,6 +576,20 @@ def __init__(self, config, **kwargs): self._max_concurrent_fetch = int( getattr(tool_cfg, 'max_concurrent_fetch', 3) or 3) if tool_cfg else 3 + # Hard cap (seconds) per URL for asyncio.wait_for around run_in_executor. + # When hit, this URL gets empty content + fetch_error; other URLs in the + # same web_search call keep their already-fetched bodies. Set 0 to disable + # the asyncio cap (underlying urllib/Jina timeouts still apply). + if tool_cfg is not None and hasattr(tool_cfg, 'per_url_fetch_timeout'): + self._per_url_fetch_timeout_s = float( + tool_cfg.per_url_fetch_timeout) + else: + self._per_url_fetch_timeout_s = default_per_url_fetch_timeout_s( + self._fetch_timeout, + self._fetch_retries, + self._jina_direct_fetch_timeout, + self._jina_playwright_timeout_ms, + ) self._max_concurrent_summarization = int( getattr(tool_cfg, 'max_concurrent_summarization', 5) or 5) if tool_cfg else 5 @@ -464,6 +617,20 @@ def __init__(self, config, **kwargs): getattr(tool_cfg, 'summarization_timeout', 90.0) or 90.0) if tool_cfg else 90.0 + # Large payload spill (write bodies to disk; keep JSON small) + self._spill_enabled = bool( + getattr(tool_cfg, 'spill_large_results', + True)) if tool_cfg else True + self._spill_max_inline_chars = int( + getattr(tool_cfg, 'spill_max_inline_chars', 120000) + or 120000) if tool_cfg else 120000 + self._spill_subdir = str( + getattr(tool_cfg, 'spill_subdir', 'web_search_artifacts') + or 'web_search_artifacts') if tool_cfg else 'web_search_artifacts' + self._spill_preview_chars = int( + getattr(tool_cfg, 'spill_preview_chars', 600) + or 600) if tool_cfg else 600 + # Reranking config self._enable_rerank = bool(getattr(tool_cfg, 'enable_rerank', False)) if tool_cfg else False @@ -508,6 +675,11 @@ async def connect(self) -> None: api_key=self._api_keys.get('serpapi'), provider=self._serpapi_provider, ) + elif engine_type == 'tavily': + self._engines[engine_type] = engine_cls( + api_key=self._api_keys.get('tavily'), + request_timeout=self._tavily_request_timeout, + ) else: # arxiv self._engines[engine_type] = engine_cls() @@ -519,11 +691,35 @@ async def connect(self) -> None: if not self._engines: raise RuntimeError('No search engines could be initialized') - self._content_fetcher = get_content_fetcher( - self._fetcher_type, - timeout=self._fetch_timeout, - retries=self._fetch_retries, - ) + wcfg = getattr(getattr(self.config, 'tools', None), 'web_search', None) + _fk: Dict[str, Any] = { + 'timeout': self._fetch_timeout, + 'retries': self._fetch_retries, + 'tavily_api_key': self._api_keys.get('tavily'), + 'direct_fetch_fallback': self._jina_direct_fetch_fallback, + 'direct_fetch_timeout': self._jina_direct_fetch_timeout, + 'playwright_fetch_fallback': self._jina_playwright_fetch_fallback, + 'playwright_retry_min_chars': + self._jina_playwright_retry_min_chars, + 'playwright_timeout_ms': self._jina_playwright_timeout_ms, + 'playwright_settle_ms': self._jina_playwright_settle_ms, + } + if wcfg is not None: + _fk.update({ + 'tavily_extract_depth': + getattr(wcfg, 'tavily_extract_depth', 'advanced'), + 'tavily_extract_format': + getattr(wcfg, 'tavily_extract_format', 'markdown'), + 'tavily_extract_chunks_per_source': + int(getattr(wcfg, 'tavily_extract_chunks_per_source', 3) or 3), + 'tavily_extract_include_images': + bool(getattr(wcfg, 'tavily_extract_include_images', False)), + 'tavily_extract_include_favicon': + bool(getattr(wcfg, 'tavily_extract_include_favicon', False)), + 'tavily_extract_include_usage': + bool(getattr(wcfg, 'tavily_extract_include_usage', False)), + }) + self._content_fetcher = get_content_fetcher(self._fetcher_type, **_fk) # Use daemon threads: tool-call timeouts can cancel the awaiting coroutine, # but not the underlying sync network calls running in executor threads. self._executor = DaemonThreadPoolExecutor( @@ -716,6 +912,59 @@ async def _fetch_content_async(self, url: str) -> Dict[str, Any]: return await loop.run_in_executor(self._executor, self._fetch_content_sync, url) + def _url_log_preview(self, url: str, max_len: int = 220) -> str: + u = (url or '').strip() + if len(u) <= max_len: + return u + return u[:max_len] + '...' + + async def _fetch_content_async_bounded(self, url: str) -> Dict[str, Any]: + """ + Fetch one URL with optional asyncio deadline and progress logs. + + Executor threads are not cancelled on timeout; the model still receives + partial results for other URLs in the same batch. + """ + u = (url or '').strip() + preview = self._url_log_preview(u) + t0 = time.perf_counter() + cap = float(self._per_url_fetch_timeout_s) + logger.info('[web_search] fetch start url=%s', preview) + try: + if cap > 0: + out = await asyncio.wait_for( + self._fetch_content_async(u), + timeout=cap, + ) + else: + out = await self._fetch_content_async(u) + except asyncio.TimeoutError: + elapsed = time.perf_counter() - t0 + logger.warning( + '[web_search] fetch TIMEOUT url=%s elapsed=%.1fs cap=%.1fs — ' + 'this URL is dropped for this response; others are unchanged', + preview, elapsed, cap) + return { + 'url': u, + 'content': '', + 'fetch_success': False, + 'fetcher': 'web_search', + 'fetch_error': f'per_url_fetch_timeout ({cap:g}s)', + 'fetch_timed_out': True, + } + elapsed = time.perf_counter() - t0 + src = (out or {}).get('content_source') or (out or {}).get( + 'fetcher', '') or '' + ok = bool((out or {}).get('fetch_success')) + logger.info( + '[web_search] fetch done url=%s elapsed=%.2fs ok=%s source=%s', + preview, elapsed, ok, src) + return out if out is not None else { + 'url': u, + 'content': '', + 'fetch_success': False, + } + async def _fetch_multiple_async(self, urls: List[str]) -> List[Dict[str, Any]]: """Fetch multiple URLs concurrently with semaphore.""" @@ -723,23 +972,41 @@ async def _fetch_multiple_async(self, async def _bounded_fetch(url: str) -> Dict[str, Any]: async with semaphore: - return await self._fetch_content_async(url) + return await self._fetch_content_async_bounded(url) tasks = [_bounded_fetch(url) for url in urls] return await asyncio.gather(*tasks) - def _do_search(self, engine_type: str, engine: SearchEngine, - engine_cls: Type[SearchEngine], - tool_args: Dict[str, Any]) -> List[Dict[str, Any]]: - """Perform search using the specified engine and return raw results.""" + def _do_search( + self, engine_type: str, engine: SearchEngine, + engine_cls: Type[SearchEngine], + tool_args: Dict[str, + Any]) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Perform search; returns (result rows, extra top-level metadata e.g. Tavily).""" try: - # Build request using engine's method - request = engine_cls.build_request_from_args(**tool_args) + merged = dict(tool_args) + if engine_type == 'tavily' and getattr(self, '_tavily_defaults', + None): + merged = {**self._tavily_defaults, **merged} + # Keys only for engine / fetcher YAML, not TavilySearchRequest + for _k in ('request_timeout', 'tavily_extract_depth', + 'tavily_extract_format', + 'tavily_extract_chunks_per_source', + 'tavily_extract_include_images', + 'tavily_extract_include_favicon', + 'tavily_extract_include_usage'): + merged.pop(_k, None) + request = engine_cls.build_request_from_args(**merged) result = engine.search(request) - return result.to_list() + rows = result.to_list() + extra: Dict[str, Any] = {} + from ms_agent.tools.search.tavily.schema import TavilySearchResult + if isinstance(result, TavilySearchResult): + extra = result.extra_response_fields() + return rows, extra except Exception as e: logger.error(f'Search failed ({engine_type}): {e}') - return [] + return [], {} async def _execute_search(self, engine_type: str, tool_args: Dict[str, Any]) -> str: @@ -763,6 +1030,8 @@ async def _execute_search(self, engine_type: str, 'message': 'Query is required.' }) + call_id_for_spill = str(tool_args.pop('__call_id', '') or '') + # Get fetch_content preference, default to configured value fetch_content = tool_args.pop('fetch_content', self._fetch_content_default) @@ -787,20 +1056,22 @@ async def _execute_search(self, engine_type: str, # Perform search loop = asyncio.get_event_loop() - search_results = await loop.run_in_executor(self._executor, - self._do_search, - engine_type, engine, - engine_cls, tool_args) + search_results, tavily_extra = await loop.run_in_executor( + self._executor, self._do_search, engine_type, engine, engine_cls, + tool_args) if not search_results: - return _json_dumps({ + out_empty: Dict[str, Any] = { 'status': 'ok', 'query': query, 'engine': engine_type, 'count': 0, 'results': [], 'message': 'No search results found.', - }) + } + if tavily_extra: + out_empty.update(tavily_extra) + return _json_dumps(out_empty) original_count = len(search_results) @@ -816,13 +1087,25 @@ async def _execute_search(self, engine_type: str, f'Reranked {original_count} results to top {len(search_results)} ' f'for query: {query[:50]}...') - # Step 3: Fetch content for (filtered) results + # Step 3: Fetch content for (filtered) results (skip URLs already filled e.g. Tavily raw_content) + fetch_attempts = 0 + fetch_timeouts = 0 if fetch_content and self._content_fetcher: search_results = SearchResultReranker.deduplicate_by_url( search_results) - urls = [r.get('url') for r in search_results if r.get('url')] + urls: List[str] = [] + for r in search_results: + u = r.get('url') + if not u: + continue + if r.get('fetch_success') and (r.get('content') or '').strip(): + continue + urls.append(u) if urls: + fetch_attempts = len(urls) fetch_results = await self._fetch_multiple_async(urls) + fetch_timeouts = sum(1 for r in fetch_results + if r.get('fetch_timed_out')) # Merge search metadata with fetched content url_to_fetch = {r['url']: r for r in fetch_results} @@ -833,6 +1116,16 @@ async def _execute_search(self, engine_type: str, sr['content'] = fetched.get('content', '') sr['fetch_success'] = fetched.get( 'fetch_success', False) + if fetched.get('fetch_error'): + sr['fetch_error'] = fetched['fetch_error'] + else: + sr.pop('fetch_error', None) + if fetched.get('fetch_timed_out'): + sr['fetch_timed_out'] = True + else: + sr.pop('fetch_timed_out', None) + if fetched.get('content_source'): + sr['content_source'] = fetched['content_source'] if fetched.get('published_at' ) and not sr.get('published_date'): sr['published_at'] = fetched['published_at'] @@ -979,6 +1272,8 @@ async def _execute_search(self, engine_type: str, sr.get('arxiv_id', '') or '', # arXiv short id 'abs_url': sr.get('id', '') or '', # entry_id (abstract page) + 'pdf_url': + sr.get('pdf_url', '') or '', 'abstract': sr.get('summary', '') or '', 'authors': @@ -994,6 +1289,12 @@ async def _execute_search(self, engine_type: str, if fetch_content: item['content'] = sr.get('content', '') item['fetch_success'] = sr.get('fetch_success', False) + if sr.get('fetch_error'): + item['fetch_error'] = sr.get('fetch_error') + if sr.get('fetch_timed_out'): + item['fetch_timed_out'] = True + if sr.get('content_source'): + item['content_source'] = sr.get('content_source') # Add summarization metadata if applicable if sr.get('content_summarized'): item['content_summarized'] = True @@ -1006,6 +1307,14 @@ async def _execute_search(self, engine_type: str, # Include snippet if available for non-arxiv engines item['summary'] = sr.get('summary', '') + if engine_type == 'tavily': + if sr.get('score') is not None: + item['score'] = sr.get('score') + if sr.get('tavily_images'): + item['images'] = sr.get('tavily_images') + if sr.get('favicon'): + item['favicon'] = sr.get('favicon') + # Add item to results for all engines output_results.append(item) @@ -1017,6 +1326,14 @@ async def _execute_search(self, engine_type: str, 'count': len(output_results), 'results': output_results, } + if fetch_content and self._content_fetcher: + response['fetch_stats'] = { + 'per_url_timeout_s': self._per_url_fetch_timeout_s, + 'urls_fetched_this_call': fetch_attempts, + 'urls_timed_out': fetch_timeouts, + } + if tavily_extra: + response.update(tavily_extra) # Add optimization info if self._enable_rerank or self._enable_summarization: @@ -1059,6 +1376,29 @@ async def _execute_search(self, engine_type: str, 'summarization_usage_process_total' ] = WebSearchTool.get_global_summarization_usage() # yapf: disable + if self._spill_enabled: + od = getattr(self, 'output_dir', None) or getattr( + getattr(self, 'config', None), 'output_dir', '') or '' + if od: + try: + new_results, spill_meta = maybe_spill_web_search_payload( + output_dir=od, + spill_subdir=self._spill_subdir, + spill_max_inline_chars=self._spill_max_inline_chars, + spill_preview_chars=self._spill_preview_chars, + query=query, + engine=engine_type, + results=response['results'], + call_id=call_id_for_spill, + ) + if spill_meta: + response['results'] = new_results + response['spill'] = spill_meta + except Exception as e: + logger.warning( + f'web_search spill failed (returning full inline JSON): {e}' + ) + return _json_dumps(response) async def fetch_page(self, url: str) -> str: @@ -1069,7 +1409,7 @@ async def fetch_page(self, url: str) -> str: 'message': 'URL is required.' }) - result = await self._fetch_content_async(url.strip()) + result = await self._fetch_content_async_bounded(url.strip()) return _json_dumps({ 'status': @@ -1082,6 +1422,10 @@ async def fetch_page(self, url: str) -> str: result.get('published_at', ''), 'fetch_success': result.get('fetch_success', False), + 'fetch_error': + result.get('fetch_error', ''), + 'fetch_timed_out': + bool(result.get('fetch_timed_out')), 'chunks': result.get('chunks') if self._enable_chunking else None, }) @@ -1122,3 +1466,7 @@ async def arxiv_search(self, **kwargs) -> str: async def serpapi_search(self, **kwargs) -> str: """Search using SerpApi engine.""" return await self._execute_search('serpapi', kwargs) + + async def tavily_search(self, **kwargs) -> str: + """Search using Tavily engine.""" + return await self._execute_search('tavily', kwargs) diff --git a/ms_agent/tools/search_engine.py b/ms_agent/tools/search_engine.py index 6bb0be4d6..66e05de8e 100644 --- a/ms_agent/tools/search_engine.py +++ b/ms_agent/tools/search_engine.py @@ -1,8 +1,8 @@ import os import threading +from dotenv import load_dotenv from typing import Any, Dict, Optional -from dotenv import load_dotenv from ms_agent.config.env import Env from ms_agent.tools.search.arxiv import ArxivSearch from ms_agent.tools.search.exa import ExaSearch diff --git a/ms_agent/tools/split_task.py b/ms_agent/tools/split_task.py deleted file mode 100644 index ea1e1c086..000000000 --- a/ms_agent/tools/split_task.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import asyncio -from concurrent.futures import ThreadPoolExecutor, as_completed - -from ms_agent.llm.utils import Tool -from ms_agent.tools.base import ToolBase -from ms_agent.utils.utils import escape_yaml_string -from omegaconf import DictConfig - - -class SplitTask(ToolBase): - """A tool special for task splitting""" - - def __init__(self, config: DictConfig, **kwargs): - super().__init__(config) - if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'): - self.tag_prefix = getattr(config.tools.split_task, 'tag_prefix', - 'worker-') - else: - self.tag_prefix = kwargs.get('tag_prefix', 'worker-') - self.round = 0 - - async def connect(self): - pass - - async def cleanup(self): - pass - - async def _get_tools_inner(self): - return { - 'split_task': [ - Tool( - tool_name='split_to_sub_task', - server_name='split_task', - description= - 'Split complex task into sub tasks and start them, for example, ' - 'split a website generation task into sub tasks, ' - 'you plan the framework, include code files and classes and functions, and give the detail ' - 'information to the system and query field of the subtask, then ' - 'let each subtask to write a single file', - parameters={ - 'type': 'object', - 'properties': { - 'tasks': { - 'type': - 'array', - 'description': - 'MANDATORY: Each element is a dict, which must contains two fields: ' - '`system`(str) and `query`(str) to start one sub task.' - } - }, - 'required': ['tasks'], - 'additionalProperties': False - }) - ] - } - - async def call_tool(self, server_name: str, *, tool_name: str, - tool_args: dict): - """ - 1. LLMAgent will be used to start subtask - 2. config will be inherited from the parent task - 3. Supports both parallel and sequential execution modes - """ - from ms_agent.agent import LLMAgent - - tasks = tool_args.get('tasks') - execution_mode = tool_args.get( - 'execution_mode', 'sequential') # 'parallel' or 'sequential' - - def run_agent_sync(i, task): - """Synchronous wrapper for agent execution""" - system = task['system'] - query = task['query'] - config = DictConfig(self.config) - if not hasattr(config, 'prompt'): - config.prompt = DictConfig({}) - config.prompt.system = escape_yaml_string(system) - trust_remote_code = getattr(config, 'trust_remote_code', False) - agent = LLMAgent( - config=config, - trust_remote_code=trust_remote_code, - tag=f'{config.tag}-r{self.round}-{self.tag_prefix}{i}', - load_cache=getattr(config, 'load_cache', False)) - - # Run async agent.run() in sync context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(agent.run(query)) - finally: - loop.close() - - result = [] - if execution_mode == 'parallel': - # Use ThreadPoolExecutor for parallel execution - with ThreadPoolExecutor() as executor: - futures = { - executor.submit(run_agent_sync, i, task): i - for i, task in enumerate(tasks) - } - - # Collect results as they complete - for future in as_completed(futures): - i = futures[future] - try: - r = future.result() - result.append((i, r)) - except Exception as e: - result.append( - (i, f'Subtask{i} failed with error: {e}')) - - # Sort by task index to maintain order - result.sort(key=lambda x: x[0]) - result = [r[1] for r in result] - else: # sequential - for i, task in enumerate(tasks): - try: - r = run_agent_sync(i, task) - result.append(r) - except Exception as e: - result.append(f'Subtask{i} failed with error: {e}') - - res = [] - for messages in result: - if isinstance(messages, list): - content = messages[-1].content - if len(content) > 2048: - content = content[:2048] - else: - content = str(messages) - res.append(content) - - self.round += 1 - - formatted_result = '' - for i in range(len(res)): - formatted_result += f'SplitTask{i}:{res[i]}\n' - - return formatted_result diff --git a/ms_agent/tools/task_control_tool.py b/ms_agent/tools/task_control_tool.py new file mode 100644 index 000000000..a0c781e94 --- /dev/null +++ b/ms_agent/tools/task_control_tool.py @@ -0,0 +1,117 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import json +from omegaconf import DictConfig +from typing import Any, Dict, Optional + +from ms_agent.llm.utils import Tool +from ms_agent.tools.base import ToolBase +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SERVER = 'task_control' + + +class TaskControlTool(ToolBase): + """Exposes background task management to the LLM. + + Provides three tools: + - list_tasks: show all background tasks and their status + - cancel_task: kill a running background task by task_id + - push_to_background: escape a currently-blocking sync agent call to background + """ + + def __init__(self, config: DictConfig, **kwargs): + super().__init__(config) + self._task_manager = None + self._agent_tool = None + + def set_task_manager(self, task_manager) -> None: + self._task_manager = task_manager + + def set_agent_tool(self, agent_tool) -> None: + """Wire up the AgentTool so push_to_background can call escape_to_background.""" + self._agent_tool = agent_tool + + async def connect(self) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def _get_tools_inner(self) -> Dict[str, Any]: + return { + _SERVER: [ + Tool( + tool_name='list_tasks', + server_name=_SERVER, + description= + ('List all background tasks and their current status. ' + 'Returns task_id, tool_name, description, status, and duration.' + ), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False, + }, + ), + Tool( + tool_name='cancel_task', + server_name=_SERVER, + description= + 'Cancel a running background task by its task_id.', + parameters={ + 'type': 'object', + 'properties': { + 'task_id': { + 'type': + 'string', + 'description': + 'The task_id returned by the async_launched response.', + } + }, + 'required': ['task_id'], + 'additionalProperties': False, + }, + ), + ] + } + + async def call_tool(self, server_name: str, *, tool_name: str, + tool_args: dict) -> str: + if self._task_manager is None: + return 'TaskManager not available.' + + if tool_name == 'list_tasks': + tasks = list(self._task_manager._tasks.values()) + if not tasks: + return 'No background tasks registered.' + rows = [] + for t in tasks: + duration = '' + if t.ended_at: + duration = f'{t.ended_at - t.started_at:.1f}s' + elif t.status == 'running': + import time + duration = f'{time.monotonic() - t.started_at:.1f}s (running)' + rows.append({ + 'task_id': t.task_id, + 'tool_name': t.tool_name, + 'description': t.description, + 'status': t.status, + 'duration': duration, + }) + return json.dumps(rows, ensure_ascii=False, indent=2) + + if tool_name == 'cancel_task': + task_id = tool_args.get('task_id', '') + task = self._task_manager.get_task(task_id) + if task is None: + return f'Task "{task_id}" not found.' + if task.status != 'running': + return f'Task "{task_id}" is already {task.status}.' + self._task_manager.kill(task_id) + return f'Task "{task_id}" cancelled.' + + return f'Unknown tool: {tool_name}' diff --git a/ms_agent/tools/todolist_tool.py b/ms_agent/tools/todolist_tool.py index aee860134..a2ae94bb5 100644 --- a/ms_agent/tools/todolist_tool.py +++ b/ms_agent/tools/todolist_tool.py @@ -1,9 +1,9 @@ +import json import os import time from dataclasses import dataclass from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..161177088 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -2,6 +2,7 @@ import asyncio import importlib import inspect +import json import os import sys import uuid @@ -9,7 +10,6 @@ from types import TracebackType from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool, ToolCall from ms_agent.tools.agent_tool import AgentTool from ms_agent.tools.base import ToolBase @@ -17,8 +17,10 @@ from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.tools.image_generator import ImageGenerator from ms_agent.tools.mcp_client import MCPClient +from ms_agent.tools.search.localsearch_tool import LocalSearchTool +from ms_agent.tools.search.sirchmunk_search import \ + effective_localsearch_settings from ms_agent.tools.search.websearch_tool import WebSearchTool -from ms_agent.tools.split_task import SplitTask from ms_agent.tools.todolist_tool import TodoListTool from ms_agent.tools.video_generator import VideoGenerator from ms_agent.utils import get_logger @@ -47,8 +49,6 @@ def __init__(self, self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False - if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'): - self.extra_tools.append(SplitTask(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'image_generator'): self.extra_tools.append(ImageGenerator(config)) @@ -76,10 +76,12 @@ def __init__(self, self.extra_tools.append(CodeExecutionTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'financial_data_fetcher'): - from ms_agent.tools.findata.findata_fetcher import FinancialDataFetcher + from ms_agent.tools.findata.findata_fetcher import \ + FinancialDataFetcher self.extra_tools.append(FinancialDataFetcher(config)) - if hasattr(config, 'tools') and getattr(config.tools, 'agent_tools', - None): + if hasattr(config, + 'tools') and (getattr(config.tools, 'agent_tools', None) + or hasattr(config.tools, 'split_task')): agent_tool = AgentTool( config, trust_remote_code=self.trust_remote_code) if agent_tool.enabled: @@ -88,6 +90,11 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + if effective_localsearch_settings(config) is not None: + self.extra_tools.append(LocalSearchTool(config)) + if hasattr(config, 'tools') and hasattr(config.tools, 'task_control'): + from ms_agent.tools.task_control_tool import TaskControlTool + self.extra_tools.append(TaskControlTool(config)) self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, @@ -226,6 +233,12 @@ async def single_call_tool(self, tool_info: ToolCall): call_args = dict(tool_args or {}) call_id = tool_info.get('id') or str(uuid.uuid4()) call_args['__call_id'] = call_id + elif isinstance(tool_ins, + LocalCodeExecutionTool) and tool_name.endswith( + f'{self.TOOL_SPLITER}shell_executor'): + call_args = dict(tool_args or {}) + call_args['__call_id'] = tool_info.get('id') or str( + uuid.uuid4()) response = await asyncio.wait_for( tool_ins.call_tool( server_name, diff --git a/ms_agent/utils/artifact_manager.py b/ms_agent/utils/artifact_manager.py new file mode 100644 index 000000000..2d73457f0 --- /dev/null +++ b/ms_agent/utils/artifact_manager.py @@ -0,0 +1,121 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Spill large tool outputs to disk under output_dir/.ms_agent_artifacts/.""" + +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Any + + +class ArtifactManager: + """When combined stdout+stderr exceeds *max_combined_bytes*, write to artifact file.""" + + def __init__( + self, + output_dir: Path | str, + *, + max_combined_bytes: int = 256 * 1024, + preview_head_chars: int = 4000, + preview_tail_chars: int = 2000, + artifact_subdir: str = '.ms_agent_artifacts', + ) -> None: + self._root = Path(output_dir).expanduser().resolve() + self.max_combined_bytes = max_combined_bytes + self.preview_head_chars = preview_head_chars + self.preview_tail_chars = preview_tail_chars + self._artifact_root = self._root / artifact_subdir + + def pack_text_result( + self, + *, + tool_name: str, + call_id: str, + stdout: str, + stderr: str, + extra: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Return payload fields: output, error, truncated, artifact_path (optional).""" + combined = (stdout or '') + (stderr or '') + enc = combined.encode('utf-8', errors='replace') + if len(enc) <= self.max_combined_bytes: + out: dict[str, Any] = { + 'output': stdout, + 'error': stderr or None, + 'truncated': False, + } + if extra: + out.update(extra) + return out + + safe_id = ''.join(c if c.isalnum() or c in '-_' else '_' + for c in call_id)[:120] or 'call' + rel_dir = Path(tool_name) / safe_id + out_dir = self._artifact_root / rel_dir + out_dir.mkdir(parents=True, exist_ok=True) + body = f'=== STDOUT ===\n{stdout}\n\n=== STDERR ===\n{stderr}\n' + digest = hashlib.sha256(enc).hexdigest()[:16] + fname = f'combined-{digest}.txt' + fpath = out_dir / fname + fpath.write_text(body, encoding='utf-8', errors='replace') + rel = fpath.relative_to(self._root).as_posix() + preview = _make_preview(body, self.preview_head_chars, + self.preview_tail_chars) + result = { + 'output': + stdout[:self.preview_head_chars] + if len(stdout) > self.preview_head_chars else stdout, + 'error': (stderr[:self.preview_head_chars] if stderr else None), + 'truncated': + True, + 'artifact_path': + rel, + 'preview': + preview, + 'artifact_bytes': + len(enc), + } + if extra: + result.update(extra) + return result + + def pack_json_shell_result( + self, + *, + tool_name: str, + call_id: str, + payload: dict[str, Any], + ) -> str: + """JSON-encode *payload* after applying spill rules to output/error string fields.""" + stdout = str(payload.get('output') or '') + stderr = str(payload.get('error') or '') + packed = self.pack_text_result( + tool_name=tool_name, + call_id=call_id, + stdout=stdout, + stderr=stderr, + extra={ + k: v + for k, v in payload.items() if k not in ('output', 'error') + }, + ) + # pack_text_result merged extra into top level; rebuild standard shell shape + out = { + 'success': payload.get('success'), + 'output': packed.get('output'), + 'error': packed.get('error'), + 'return_code': payload.get('return_code'), + 'truncated': packed.get('truncated', False), + } + if packed.get('artifact_path'): + out['artifact_path'] = packed['artifact_path'] + out['preview'] = packed.get('preview') + out['artifact_bytes'] = packed.get('artifact_bytes') + return json.dumps(out, ensure_ascii=False, indent=2) + + +def _make_preview(text: str, head: int, tail: int) -> str: + if len(text) <= head + tail: + return text + return (text[:head] + '\n... [truncated] ...\n' + text[-tail:]) diff --git a/ms_agent/utils/parser_utils.py b/ms_agent/utils/parser_utils.py index 5455dae52..af22034c6 100644 --- a/ms_agent/utils/parser_utils.py +++ b/ms_agent/utils/parser_utils.py @@ -1,11 +1,10 @@ +import json import os import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict, List, Optional -import json - @dataclass class ImportInfo: diff --git a/ms_agent/utils/push_to_hub.py b/ms_agent/utils/push_to_hub.py index a164603dd..2cc6d7b7a 100644 --- a/ms_agent/utils/push_to_hub.py +++ b/ms_agent/utils/push_to_hub.py @@ -1,15 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import mimetypes import os import re +import requests import shutil from abc import ABC, abstractmethod from pathlib import Path from typing import List, Optional, Tuple -import json -import requests from ms_agent.utils.logger import get_logger from ms_agent.utils.utils import (get_files_from_dir, is_package_installed, text_hash) @@ -353,8 +353,7 @@ def __init__( 'ModelScope package is not installed. Please install it with `pip install modelscope`.' ) - from modelscope.hub.api import HubApi - from modelscope.hub.api import get_endpoint + from modelscope.hub.api import HubApi, get_endpoint self.api = HubApi() self.token = token diff --git a/ms_agent/utils/snapshot.py b/ms_agent/utils/snapshot.py new file mode 100644 index 000000000..d7362405a --- /dev/null +++ b/ms_agent/utils/snapshot.py @@ -0,0 +1,234 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Lightweight snapshot utility for ms-agent output directories. + +Uses a dedicated git repo stored at /.ms_agent_snapshots/ +so it never touches or conflicts with the user's own .git directory. + +All git commands are run with GIT_DIR and GIT_WORK_TREE explicitly set, +so the snapshot repo is fully isolated from any surrounding repository. +""" +import json +import os +import subprocess +from typing import Optional + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_SNAPSHOT_DIR_NAME = '.ms_agent_snapshots' +_META_FILE = 'snapshot_meta.json' + + +def _git(args: list[str], + work_tree: str, + git_dir: str, + check: bool = True) -> subprocess.CompletedProcess: + env = os.environ.copy() + env['GIT_DIR'] = git_dir + env['GIT_WORK_TREE'] = work_tree + # Suppress interactive prompts + env['GIT_TERMINAL_PROMPT'] = '0' + return subprocess.run( + ['git'] + args, + env=env, + cwd=work_tree, + capture_output=True, + text=True, + check=check, + ) + + +def _snapshot_git_dir(output_dir: str) -> str: + return os.path.join(output_dir, _SNAPSHOT_DIR_NAME) + + +def _configure_snapshot_repo_for_automation(work_tree: str, + git_dir: str) -> None: + """Disable hook execution for the nested snapshot repo. + + Without this, Git can inherit ``init.templateDir`` / global ``core.hooksPath`` + (e.g. lefthook), so ``git commit`` runs hooks and races under concurrency + (``cannot lock ref 'HEAD'`` / hook failures). ``os.devnull`` is the portable + Git-supported way to disable hooks (POSIX ``/dev/null``, Windows ``nul``). + """ + try: + _git(['config', 'core.hooksPath', os.devnull], + work_tree=work_tree, + git_dir=git_dir, + check=False) + except Exception: + pass + + +def _ensure_repo(output_dir: str) -> str: + """Initialize the snapshot repo if it doesn't exist. Returns git_dir.""" + git_dir = _snapshot_git_dir(output_dir) + if not os.path.isdir(git_dir): + os.makedirs(git_dir, exist_ok=True) + # Use non-bare init with explicit GIT_DIR — no --bare so work tree is supported. + # Do NOT pass a path argument; GIT_DIR env var points git at our custom dir. + _git(['init'], work_tree=output_dir, git_dir=git_dir) + _git(['config', 'user.email', 'ms-agent@snapshot'], + work_tree=output_dir, + git_dir=git_dir) + _git(['config', 'user.name', 'ms-agent'], + work_tree=output_dir, + git_dir=git_dir) + # Exclude the snapshot dir itself from tracking + info_dir = os.path.join(git_dir, 'info') + os.makedirs(info_dir, exist_ok=True) + exclude_file = os.path.join(info_dir, 'exclude') + with open(exclude_file, 'a', encoding='utf-8') as f: + f.write(f'\n{_SNAPSHOT_DIR_NAME}/\n') + # Always (re)apply: repos created before this fix may still inherit hooks. + _configure_snapshot_repo_for_automation(output_dir, git_dir) + return git_dir + + +def _meta_path(output_dir: str) -> str: + return os.path.join(_snapshot_git_dir(output_dir), _META_FILE) + + +def _load_meta(output_dir: str) -> dict: + path = _meta_path(output_dir) + if os.path.exists(path): + try: + with open(path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception: + pass + return {} + + +def _save_meta(output_dir: str, meta: dict) -> None: + path = _meta_path(output_dir) + with open(path, 'w', encoding='utf-8') as f: + json.dump(meta, f, indent=2) + + +def take_snapshot(output_dir: str, + message: str, + message_count: int = 0) -> Optional[str]: + """ + Stage all changes in output_dir and create a snapshot commit. + + Args: + output_dir: The directory to snapshot. + message: Commit message (truncated to 120 chars). + message_count: Number of messages in history at snapshot time. + Stored in metadata so rollback can truncate history. + + Returns the short commit hash on success, or None if nothing to commit + or if git is unavailable. + """ + if not output_dir or not os.path.isdir(output_dir): + return None + + try: + git_dir = _ensure_repo(output_dir) + + # Stage everything (excluding .ms_agent_snapshots via info/exclude) + _git(['add', '-A'], work_tree=output_dir, git_dir=git_dir) + + # Check if there's anything to commit + status = _git(['status', '--porcelain'], + work_tree=output_dir, + git_dir=git_dir) + if not status.stdout.strip(): + return None # Nothing changed + + # Truncate message to keep commit subject readable + subject = message.strip().replace('\n', ' ')[:120] + result = _git(['commit', '--no-verify', '-m', subject], + work_tree=output_dir, + git_dir=git_dir) + + commit_hash = None + for line in result.stdout.splitlines(): + if line.startswith('['): + before_bracket = line.split(']')[0] + commit_hash = before_bracket.split()[-1] + break + if commit_hash is None: + commit_hash = 'ok' + + # Persist message_count so rollback can truncate history + meta = _load_meta(output_dir) + meta[commit_hash] = {'message_count': message_count} + _save_meta(output_dir, meta) + + return commit_hash + + except FileNotFoundError: + logger.warning_once('[snapshot] git not found — snapshots disabled.') + return None + except subprocess.CalledProcessError as e: + logger.warning(f'[snapshot] git error: {e.stderr.strip()}') + return None + except Exception as e: + logger.warning(f'[snapshot] unexpected error: {e}') + return None + + +def list_snapshots(output_dir: str) -> list[dict]: + """ + Return a list of snapshots as dicts with keys: hash, message, date, message_count. + Most recent first. + """ + git_dir = _snapshot_git_dir(output_dir) + if not os.path.isdir(git_dir): + return [] + try: + result = _git( + ['log', '--pretty=format:%h\t%ai\t%s'], + work_tree=output_dir, + git_dir=git_dir, + check=False, + ) + if result.returncode != 0: + return [] + meta = _load_meta(output_dir) + snapshots = [] + for line in result.stdout.splitlines(): + parts = line.split('\t', 2) + if len(parts) == 3: + h = parts[0] + snapshots.append({ + 'hash': + h, + 'date': + parts[1], + 'message': + parts[2], + 'message_count': + meta.get(h, {}).get('message_count', 0), + }) + return snapshots + except Exception: + return [] + + +def restore_snapshot(output_dir: str, commit_hash: str) -> tuple[bool, int]: + """ + Restore output_dir to the state at commit_hash. + + Returns (success, message_count) where message_count is the number of + messages in history at snapshot time (0 if unknown). + """ + git_dir = _snapshot_git_dir(output_dir) + if not os.path.isdir(git_dir): + logger.warning('[snapshot] No snapshot repo found.') + return False, 0 + try: + _git(['checkout', commit_hash, '--', '.'], + work_tree=output_dir, + git_dir=git_dir) + logger.info(f'[snapshot] Restored to {commit_hash}') + meta = _load_meta(output_dir) + message_count = meta.get(commit_hash, {}).get('message_count', 0) + return True, message_count + except subprocess.CalledProcessError as e: + logger.warning(f'[snapshot] restore failed: {e.stderr.strip()}') + return False, 0 diff --git a/ms_agent/utils/stats.py b/ms_agent/utils/stats.py index 7ed705a1d..beeccfc1f 100644 --- a/ms_agent/utils/stats.py +++ b/ms_agent/utils/stats.py @@ -1,13 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio +import json import os import time from datetime import datetime from typing import Any, Dict, Iterable, Optional -import json from ms_agent.llm.utils import Message - from .logger import get_logger logger = get_logger() diff --git a/ms_agent/utils/stream_writer.py b/ms_agent/utils/stream_writer.py new file mode 100644 index 000000000..cdeddec77 --- /dev/null +++ b/ms_agent/utils/stream_writer.py @@ -0,0 +1,210 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +SubAgentStreamWriter — incremental JSONL writer for sub-agent execution progress. + +When a parent agent calls a sub-agent via AgentTool, the sub-agent produces messages +incrementally. This writer appends new messages to a JSONL file on every chunk event, +allowing external tools (e.g. ``tail -f``) or the parent agent itself to watch the +sub-agent's progress in real time. + +File format (one JSON object per line): + +.. code-block:: text + + {"type": "header", "call_id": "...", "tool_name": "...", "agent_tag": "...", "ts": "..."} + {"type": "message", "index": 0, "message": {...}, "ts": "..."} + {"type": "message", "index": 1, "message": {...}, "ts": "..."} + ... + {"type": "footer", "call_id": "...", "status": "complete", "total_messages": N, "ts": "..."} + +On error the footer's *status* field is ``"error"`` and an ``"error"`` field is included. + +File path: ``{output_dir}/subagents/{call_id}.stream.jsonl`` +""" +import json +import os +import threading +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from ms_agent.utils import get_logger + +logger = get_logger() + + +class SubAgentStreamWriter: + """Thread-safe incremental JSONL writer for sub-agent chunk events. + + Each instance owns exactly one file. It deduplicates history by tracking + how many messages have already been written (``_last_written_count``), so + calling ``on_chunk(full_history)`` multiple times is safe — only the newly + appended messages are written. + + All public methods are safe to call from multiple threads. + """ + + def __init__(self, output_dir: str, call_id: str, tool_name: str) -> None: + self._call_id: str = call_id or 'unknown' + self._tool_name: str = tool_name + self._lock = threading.Lock() + self._last_written_count: int = 0 + self._closed: bool = False + self._agent_tag: Optional[str] = None + self._file = None # opened lazily in on_start + + subagents_dir = os.path.join(output_dir, 'subagents') + os.makedirs(subagents_dir, exist_ok=True) + safe_id = self._call_id.replace('/', '_').replace('\\', '_') + self._path: str = os.path.join(subagents_dir, + f'{safe_id}.stream.jsonl') + + @property + def stream_path(self) -> str: + """Absolute path to the JSONL stream file.""" + return self._path + + def on_start(self, agent_tag: Optional[str]) -> None: + """Open the file and write the header record. + + Args: + agent_tag: The sub-agent's tag string, if known at start time. + May be ``None`` when running in a subprocess (tag is + only resolved after the process finishes). + """ + with self._lock: + if self._closed: + return + self._agent_tag = agent_tag + try: + self._file = open(self._path, 'w', encoding='utf-8') + self._write_line({ + 'type': 'header', + 'call_id': self._call_id, + 'tool_name': self._tool_name, + 'agent_tag': agent_tag or '', + 'ts': _now_iso(), + }) + except Exception as exc: + logger.warning('SubAgentStreamWriter: failed to open %s: %s', + self._path, exc) + self._file = None + + def on_chunk(self, history: Any) -> None: + """Append only new messages from *history* since the last call. + + Args: + history: The full accumulated message list returned by a streaming + chunk. May be ``None`` or an empty list; in that case + nothing is written. + """ + messages = _coerce_to_list(history) + if not messages: + return + with self._lock: + if self._closed or self._file is None: + return + for msg in messages[self._last_written_count:]: + self._write_line({ + 'type': 'message', + 'index': self._last_written_count, + 'message': _msg_to_dict(msg), + 'ts': _now_iso(), + }) + self._last_written_count += 1 + + def on_end(self, history: Any) -> None: + """Flush any remaining messages, write footer record, then close. + + Args: + history: Final full message list (same shape as ``on_chunk``). + """ + # Write any messages that arrived in the final chunk before closing. + self.on_chunk(history) + with self._lock: + if self._closed: + return + self._closed = True + if self._file is not None: + try: + self._write_line({ + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'complete', + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + }) + self._file.flush() + self._file.close() + except Exception as exc: + logger.warning( + 'SubAgentStreamWriter: close error on %s: %s', + self._path, exc) + finally: + self._file = None + + def on_error(self, error: str) -> None: + """Write an error footer and close the file. + + Args: + error: Human-readable error description. + """ + with self._lock: + if self._closed: + return + self._closed = True + if self._file is not None: + try: + self._write_line({ + 'type': 'footer', + 'call_id': self._call_id, + 'agent_tag': self._agent_tag or '', + 'status': 'error', + 'error': error, + 'total_messages': self._last_written_count, + 'ts': _now_iso(), + }) + self._file.flush() + self._file.close() + except Exception: + pass + finally: + self._file = None + + # ── private helpers ──────────────────────────────────────────────────── + + def _write_line(self, record: Dict[str, Any]) -> None: + """Serialize *record* as JSON and append a newline. + + Caller **must** hold ``self._lock``. Each line is flushed immediately + so that ``tail -f`` sees it without buffering. + """ + if self._file is None: + return + try: + self._file.write(json.dumps(record, ensure_ascii=False) + '\n') + self._file.flush() + except Exception as exc: + logger.warning('SubAgentStreamWriter: write failed: %s', exc) + + +# ── module-level helpers ──────────────────────────────────────────────────── + + +def _now_iso() -> str: + """Return the current UTC time as an ISO-8601 string.""" + return datetime.now(timezone.utc).isoformat() + + +def _coerce_to_list(value: Any) -> List[Any]: + """Return *value* if it is a list, otherwise an empty list.""" + return value if isinstance(value, list) else [] + + +def _msg_to_dict(msg: Any) -> Dict[str, Any]: + """Convert a Message object (or plain dict) to a serialisable dict.""" + if hasattr(msg, 'to_dict'): + return msg.to_dict() + if isinstance(msg, dict): + return msg + return {'role': 'unknown', 'content': str(msg)} diff --git a/ms_agent/utils/task_manager.py b/ms_agent/utils/task_manager.py new file mode 100644 index 000000000..6edc00280 --- /dev/null +++ b/ms_agent/utils/task_manager.py @@ -0,0 +1,134 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import asyncio +import multiprocessing as mp +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + + +@dataclass +class BackgroundTask: + task_id: str + task_type: str # 'agent' | 'shell' + tool_name: str # which tool spawned this + description: str + status: str = 'running' # 'running' | 'completed' | 'failed' | 'killed' + proc: Optional[Any] = field( + default=None, repr=False) # mp.Process or asyncio.Task + result: Optional[str] = None + error: Optional[str] = None + started_at: float = field(default_factory=time.monotonic) + ended_at: Optional[float] = None + + +class TaskManager: + """ + Agent-level registry for background tasks (agent sub-tasks, shell tasks, etc.). + Holds a notification queue that LLMAgent drains each turn to inject + completion notices into the conversation. + """ + + def __init__(self): + self._tasks: Dict[str, BackgroundTask] = {} + self._lock = asyncio.Lock() + self._notification_queue: asyncio.Queue = asyncio.Queue() + + def register( + self, + task_type: str, + tool_name: str, + description: str, + proc: Optional[Any] = None, + task_id: Optional[str] = None, + ) -> str: + task_id = task_id or uuid.uuid4().hex[:12] + task = BackgroundTask( + task_id=task_id, + task_type=task_type, + tool_name=tool_name, + description=description, + proc=proc, + ) + self._tasks[task_id] = task + logger.info( + f'[TaskManager] registered {task_type} task {task_id}: {description}' + ) + return task_id + + async def complete(self, task_id: str, result: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + task.status = 'completed' + task.result = result + task.ended_at = time.monotonic() + await self._notification_queue.put(self._format_notification(task)) + + async def fail(self, task_id: str, error: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + task.status = 'failed' + task.error = error + task.ended_at = time.monotonic() + await self._notification_queue.put(self._format_notification(task)) + + def kill(self, task_id: str) -> None: + task = self._tasks.get(task_id) + if task is None: + return + if task.status != 'running': + return + if task.proc is not None: + try: + if isinstance(task.proc, mp.Process): + task.proc.terminate() + elif asyncio.isfuture(task.proc) or asyncio.iscoroutine( + task.proc): + task.proc.cancel() + except Exception as e: + logger.warning(f'[TaskManager] kill {task_id} failed: {e}') + task.status = 'killed' + task.ended_at = time.monotonic() + + def kill_all(self) -> None: + for task_id in list(self._tasks): + if self._tasks[task_id].status == 'running': + self.kill(task_id) + + def drain_notifications(self) -> List[str]: + """Drain all pending notifications synchronously. Called from run_loop.""" + notifications = [] + while not self._notification_queue.empty(): + try: + notifications.append(self._notification_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return notifications + + def get_task(self, task_id: str) -> Optional[BackgroundTask]: + return self._tasks.get(task_id) + + def running_tasks(self) -> List[BackgroundTask]: + return [t for t in self._tasks.values() if t.status == 'running'] + + @staticmethod + def _format_notification(task: BackgroundTask) -> str: + result_line = f'\n{task.result}' if task.result else '' + error_line = f'\n{task.error}' if task.error else '' + duration = '' + if task.ended_at: + duration = f'\n{task.ended_at - task.started_at:.1f}' + return (f'\n' + f'{task.task_id}\n' + f'{task.task_type}\n' + f'{task.tool_name}\n' + f'{task.description}\n' + f'{task.status}' + f'{result_line}{error_line}{duration}\n' + f'') diff --git a/ms_agent/utils/thread_util.py b/ms_agent/utils/thread_util.py index 16e46eba9..380dd99b3 100644 --- a/ms_agent/utils/thread_util.py +++ b/ms_agent/utils/thread_util.py @@ -4,9 +4,9 @@ import weakref from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps +from tqdm.auto import tqdm from ms_agent.utils.logger import get_logger -from tqdm.auto import tqdm logger = get_logger() @@ -102,7 +102,8 @@ def weakref_cb(_, q=self._work_queue): thread_name = '%s_%d' % (self._thread_name_prefix or self, num_threads) # Import internal helpers from stdlib to keep behavior consistent. - from concurrent.futures.thread import _worker, _threads_queues # type: ignore + from concurrent.futures.thread import ( # type: ignore + _threads_queues, _worker) t = threading.Thread( name=thread_name, diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index a6d0bc87d..91d9c6d42 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -5,21 +5,20 @@ import html import importlib import importlib.util +import json import os.path import re +import requests import subprocess import sys import time +import yaml from contextlib import contextmanager from io import BytesIO +from omegaconf import DictConfig, OmegaConf from pathlib import Path from typing import List, Optional, Tuple, Union -import json -import requests -import yaml -from omegaconf import DictConfig, OmegaConf - from .constants import DEFAULT_MEMORY_DIR from .logger import get_logger @@ -240,8 +239,8 @@ def read_history(output_dir: str, task: str): TypeError / AttributeError: If the deserialized JSON data lacks expected keys or structure for Message objects. """ - from ms_agent.llm import Message from ms_agent.config import Config + from ms_agent.llm import Message cache_dir = os.path.join(output_dir, DEFAULT_MEMORY_DIR) os.makedirs(cache_dir, exist_ok=True) config_file = os.path.join(cache_dir, f'{task}.yaml') diff --git a/ms_agent/utils/workspace_policy.py b/ms_agent/utils/workspace_policy.py new file mode 100644 index 000000000..c11edf4db --- /dev/null +++ b/ms_agent/utils/workspace_policy.py @@ -0,0 +1,207 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Workspace path policy: allow-roots (default output_dir) and optional deny globs.""" + +from __future__ import annotations + +import fnmatch +import os +import re +from pathlib import Path +from typing import Iterable, Sequence + + +class WorkspacePolicyError(ValueError): + """Raised when a path or command violates workspace policy.""" + + +class WorkspacePolicyKernel: + """Resolve user paths under allowed workspace roots; optional shell read-only rules.""" + + def __init__( + self, + output_dir: Path | str, + *, + extra_allow_roots: Sequence[str | Path] | None = None, + deny_globs: Sequence[str] | None = None, + shell_default_mode: str = 'workspace_write', + shell_network_enabled: bool = False, + max_command_chars: int = 8192, + ) -> None: + self._output = Path(output_dir).expanduser().resolve() + self._roots: list[Path] = [self._output] + if extra_allow_roots: + for r in extra_allow_roots: + p = Path(r).expanduser().resolve() + if p not in self._roots: + self._roots.append(p) + if deny_globs is None or len(tuple(deny_globs)) == 0: + self._deny_globs: tuple[str, ...] = ('**/.git/**', ) + else: + self._deny_globs = tuple(deny_globs) + self.shell_default_mode = shell_default_mode + self.shell_network_enabled = shell_network_enabled + self.max_command_chars = max_command_chars + + @property + def workspace_root(self) -> Path: + return self._output + + @property + def allow_roots(self) -> tuple[Path, ...]: + return tuple(self._roots) + + @property + def deny_globs(self) -> tuple[str, ...]: + return self._deny_globs + + def resolve_under_roots(self, user_path: str | Path) -> Path: + """Resolve *user_path* to an absolute path that must lie under one allow root.""" + raw = Path(user_path).expanduser() + if raw.is_absolute(): + resolved = raw.resolve() + else: + resolved = (self._output / raw).resolve() + for root in self._roots: + try: + resolved.relative_to(root) + break + except ValueError: + continue + else: + raise WorkspacePolicyError( + f'Path is outside allowed workspace roots: {resolved}') + if self._is_denied(resolved): + raise WorkspacePolicyError( + f'Path matches a deny_globs pattern: {resolved}') + return resolved + + def _is_denied(self, path: Path) -> bool: + if not self._deny_globs: + return False + rel = None + try: + rel = path.relative_to(self._output) + except ValueError: + rel = path + rel_s = rel.as_posix() + for pat in self._deny_globs: + if fnmatch.fnmatch(rel_s, pat) or fnmatch.fnmatch(path.name, pat): + return True + if fnmatch.fnmatch(str(path), pat): + return True + return False + + def path_is_allowed(self, path: Path) -> bool: + path = path.expanduser().resolve() + for root in self._roots: + try: + path.relative_to(root) + break + except ValueError: + continue + else: + return False + return not self._is_denied(path) + + def assert_shell_command_allowed(self, command: str) -> None: + """Length and mode-based checks before executing shell.""" + if not command or not command.strip(): + raise WorkspacePolicyError('Empty shell command') + if len(command) > self.max_command_chars: + raise WorkspacePolicyError( + f'Shell command exceeds max length ({self.max_command_chars})') + + mode = self.shell_default_mode + if mode == 'read_only': + if _shell_looks_mutating_or_network(command, allow_network=False): + raise WorkspacePolicyError( + 'Shell is in read_only mode: mutating or network commands are not allowed' + ) + elif mode == 'workspace_write': + if not self.shell_network_enabled and _shell_looks_network( + command): + raise WorkspacePolicyError( + 'Network commands are disabled for shell (enable tools.code_executor.shell.network_enabled)' + ) + # future: explicit 'network' mode could allow curl etc. + + +def _shell_looks_network(command: str) -> bool: + lowered = command.lower() + tokens = ( + 'curl ', + 'wget ', + 'ssh ', + 'scp ', + 'rsync ', + 'ftp ', + 'nc ', + 'netcat ', + 'pip install', + 'pip3 install', + 'npm install', + 'yarn add', + 'pnpm add', + ) + return any(t in lowered for t in tokens) + + +def _shell_looks_mutating_or_network(command: str, *, + allow_network: bool) -> bool: + if not allow_network and _shell_looks_network(command): + return True + # redirection that creates/overwrites files + if re.search(r'[>]{1,2}\s*[^\s]', command): + return True + if re.search(r'\b(rm|rmdir|mv|cp|chmod|chown|chgrp|mkdir|touch|tee)\b', + command): + return True + return False + + +def iter_files_under( + root: Path, + *, + deny_globs: Iterable[str] = (), + max_files: int = 100_000, +) -> Iterable[Path]: + """Yield files under *root* (depth-first), skipping directories matching deny globs.""" + deny = tuple(deny_globs) + count = 0 + root = root.resolve() + + def dir_skipped(dirpath: Path) -> bool: + try: + rel = dirpath.relative_to(root).as_posix() + except ValueError: + return True + for pat in deny: + if fnmatch.fnmatch(rel, pat) or fnmatch.fnmatch(rel + '/', pat): + return True + parts = rel.split('/') + for i in range(len(parts)): + sub = '/'.join(parts[:i + 1]) + if fnmatch.fnmatch(sub, pat.rstrip('/')) or fnmatch.fnmatch( + sub + '/', pat): + return True + return False + + for dirpath, dirnames, filenames in os.walk( + root, topdown=True, followlinks=False): + dp = Path(dirpath) + if dir_skipped(dp): + dirnames[:] = [] + continue + # prune skipped subdirs + keep: list[str] = [] + for d in dirnames: + child = dp / d + if dir_skipped(child): + continue + keep.append(d) + dirnames[:] = keep + for name in filenames: + count += 1 + if count > max_files: + return + yield dp / name diff --git a/ms_agent/workflow/base.py b/ms_agent/workflow/base.py index 9d484118c..e2f8b0cec 100644 --- a/ms_agent/workflow/base.py +++ b/ms_agent/workflow/base.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from abc import ABC, abstractmethod +from omegaconf import DictConfig from typing import Dict, Optional from ms_agent.config import Config -from omegaconf import DictConfig class Workflow(ABC): diff --git a/ms_agent/workflow/chain_workflow.py b/ms_agent/workflow/chain_workflow.py index 2b2b739e1..994c98066 100644 --- a/ms_agent/workflow/chain_workflow.py +++ b/ms_agent/workflow/chain_workflow.py @@ -1,10 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os +from omegaconf import DictConfig from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/workflow/dag_workflow.py b/ms_agent/workflow/dag_workflow.py index 1419ccf46..14ea1d671 100644 --- a/ms_agent/workflow/dag_workflow.py +++ b/ms_agent/workflow/dag_workflow.py @@ -1,12 +1,12 @@ # Copyright (c) Alibaba, Inc. import os from collections import defaultdict, deque +from omegaconf import DictConfig from typing import Any, Dict, List, Set from ms_agent.agent.loader import AgentLoader from ms_agent.utils import get_logger from ms_agent.workflow.base import Workflow -from omegaconf import DictConfig logger = get_logger() diff --git a/ms_agent/workflow/deep_research/research_utils.py b/ms_agent/workflow/deep_research/research_utils.py index 7d5ccdb69..1bfdbc70a 100644 --- a/ms_agent/workflow/deep_research/research_utils.py +++ b/ms_agent/workflow/deep_research/research_utils.py @@ -1,8 +1,7 @@ -from typing import Dict, List, Optional - from pydantic import BaseModel, Field from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn +from typing import Dict, List, Optional console = Console() diff --git a/ms_agent/workflow/deep_research/research_workflow.py b/ms_agent/workflow/deep_research/research_workflow.py index 03d64bf80..23ec6957b 100644 --- a/ms_agent/workflow/deep_research/research_workflow.py +++ b/ms_agent/workflow/deep_research/research_workflow.py @@ -1,11 +1,11 @@ # flake8: noqa # yapf: disable import copy +import json import os import re from typing import Any, Dict, List, Optional, Union -import json from ms_agent.llm.openai import OpenAIChat from ms_agent.utils import get_logger @@ -222,7 +222,8 @@ def generate_todo(self, **kwargs) -> None: def search(self, search_request: 'SearchRequest') -> str: from ms_agent.tools.search.exa.schema import dump_batch_search_results - from ms_agent.tools.search.search_base import SearchRequest, SearchResult + from ms_agent.tools.search.search_base import (SearchRequest, + SearchResult) if self._reuse: # Load existing search results if they exist if os.path.exists(self.workdir_structure['search']): @@ -349,7 +350,8 @@ def run(self, from ms_agent.rag.extraction_manager import extract_key_information from ms_agent.rag.schema import KeyInformation from ms_agent.tools.search.search_base import SearchResult - from ms_agent.tools.search.search_request import get_search_request_generator + from ms_agent.tools.search.search_request import \ + get_search_request_generator from ms_agent.utils.utils import remove_resource_info, text_hash special_resources: List = [] if urls_or_files: diff --git a/ms_agent/workflow/deep_research/research_workflow_beta.py b/ms_agent/workflow/deep_research/research_workflow_beta.py index 24816ae52..b490f098d 100644 --- a/ms_agent/workflow/deep_research/research_workflow_beta.py +++ b/ms_agent/workflow/deep_research/research_workflow_beta.py @@ -1,13 +1,14 @@ # yapf: disable import asyncio +import click import os import re import threading from concurrent.futures import ThreadPoolExecutor from datetime import datetime +from rich.prompt import Confirm, Prompt from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import click from ms_agent.llm.openai import OpenAIChat from ms_agent.rag.extraction_manager import extract_key_information from ms_agent.tools.search.exa.schema import dump_batch_search_results @@ -21,7 +22,6 @@ ResearchProgress, ResearchResult) from ms_agent.workflow.deep_research.research_workflow import ResearchWorkflow -from rich.prompt import Confirm, Prompt logger = get_logger() diff --git a/ms_agent/workflow/loader.py b/ms_agent/workflow/loader.py index 7e0589cfa..b11a2763c 100644 --- a/ms_agent/workflow/loader.py +++ b/ms_agent/workflow/loader.py @@ -1,8 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +from omegaconf import DictConfig, OmegaConf from typing import Dict, Optional from ms_agent.config.config import Config -from omegaconf import DictConfig, OmegaConf class WorkflowLoader: diff --git a/projects/code_genesis/architect.yaml b/projects/code_genesis/architect.yaml index 988091875..79c65d535 100644 --- a/projects/code_genesis/architect.yaml +++ b/projects/code_genesis/architect.yaml @@ -37,7 +37,7 @@ prompt: Your Steps: 1. Read topic and user story to obtain the original requirements and user stories. 2. If topic contains documentation (e.g., usage of external frameworks, project PRDs, etc.), you must read them to understand user requirements. - * When reading external documents, prioritize using `read_abbreviation_file` to reduce token usage. + * When reading external documents, prioritize using `read_file` with `abbreviate: true` to reduce token usage. 3. Design the technology selection. If the user has specified technology choices, prioritize using the user's choices and supplement any missing parts. * [Historical Error]: Failure due to mixing different technical frameworks within the same language in different files. Your framework information should prevent this issue. - Example: Mixing CommonJS and ES6 in package.json and xx.js causing runtime failure; you should specify syntax rules in framework.txt. diff --git a/projects/code_genesis/coding.yaml b/projects/code_genesis/coding.yaml index 1508a5de1..cc99b728f 100644 --- a/projects/code_genesis/coding.yaml +++ b/projects/code_genesis/coding.yaml @@ -34,7 +34,7 @@ prompt: - `postcss.config.js` enables `tailwindcss` and `autoprefixer` 3. CRITICAL: Before reading ANY file: - * FIRST use `list_files` to check which files actually exist in the project + * FIRST use `file_system---glob` (e.g. pattern `**/*`, path `.`) to list paths under the project, or `code_executor---shell_executor` with a read-only `ls`/`find` command if you prefer the shell. * NEVER read files that do not appear in the output * NEVER attempt to read files with index >= yours (they don't exist yet) * NEVER guess or assume a file exists - always verify first @@ -42,13 +42,13 @@ prompt: 3.1 CRITICAL (no hallucination about files): * Do not fully trust `protocol.txt` content; you must verify it yourself by checking existing files and reading the exact source of truth. * Before you reference/cite ANY information that comes from a file (APIs, exports, config values, routes, CSS class names, build scripts, ports, etc.), you MUST: - 1) Confirm the file exists via `list_files`, AND + 1) Confirm the file exists via `file_system---glob` (or shell listing), AND 2) Read the relevant part of that exact file via `read_file`. * You are NOT allowed to infer file contents. * If the needed file is missing from the file list, do not reference it; either create it (if allowed by your index constraints) or implement the needed logic in your current file. 4. Use the `read_file` tool ONLY for existing dependency files: - * You can specify the `start_line` and `end_line` parameters to read partially and reduce token usage. + * You can specify the `offset` (1-based start line) and `limit` (number of lines) parameters on `read_file` to read partially and reduce token usage. * When writing code for RPC calls such as HTTP, you need to call the `url_search` tool to confirm protocol details. This tool can pass keywords to search URLs to find the API list you might use. Do not call non-existent API interfaces. 5. Write code, do not use `edit_file` to write non-existent files! @@ -103,7 +103,7 @@ prompt: 8. When fixing issues and updating files: Call the `edit_file` tool or `write_file` tool, after fixing issues there's no need to check by yourself, the lsp tool will check and report issues. - 9. You can use the shell_executor to debug problems: + 9. You can use `code_executor---shell_executor` to debug problems: Example: shell_executor(command='python -c "from module import MyClass; print(vars(MyClass))"') @@ -117,10 +117,11 @@ tools: mcp: false allow_read_all_files: true include: - - read_file - - edit_file - - list_files - - write_file + - read + - edit + - write + - grep + - glob edit_file_config: diff_model: morph-v3-fast api_key: diff --git a/projects/code_genesis/refine.yaml b/projects/code_genesis/refine.yaml index 2795dc654..8cea4e407 100644 --- a/projects/code_genesis/refine.yaml +++ b/projects/code_genesis/refine.yaml @@ -46,9 +46,9 @@ prompt: a. Write test cases to verify correctness 3. Use your tools effectively - * Use `read_file` to read the original file. When using `read_file`, specify `start_line` and `end_line` to reduce token usage + * Use `read_file` to read the original file. When using `read_file`, specify `offset` (1-based start line) and `limit` (line count) to reduce token usage * If an HTTP API is unclear, use `workflow/api_search` to locate the implementation - * Use `search_file_content` to search the project for keywords you care about + * To scan the project for keywords without a dedicated search tool, use `read_file` with `abbreviate: true` on candidate files, or read likely paths with `offset`/`limit` * Shell commands must not use system directories (except `/dev/null`) 4. When run successfully, you can use EdgeOne Pages MCP tools to deploy the project. Available tools from `edgeone-pages-mcp` include: @@ -96,8 +96,9 @@ tools: include: - read_file - write_file - - list_files - edit_file + - grep + - glob edit_file_config: diff_model: morph-v3-fast api_key: diff --git a/projects/code_genesis/workflow/api_search.py b/projects/code_genesis/workflow/api_search.py index 4e380d17a..ddc933475 100644 --- a/projects/code_genesis/workflow/api_search.py +++ b/projects/code_genesis/workflow/api_search.py @@ -1,9 +1,9 @@ +import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.constants import DEFAULT_INDEX_DIR diff --git a/projects/code_genesis/workflow/coding.py b/projects/code_genesis/workflow/coding.py index e5b696373..055e91418 100644 --- a/projects/code_genesis/workflow/coding.py +++ b/projects/code_genesis/workflow/coding.py @@ -1,14 +1,15 @@ import asyncio import dataclasses +import json import os import re import shutil from collections import OrderedDict from copy import deepcopy +from omegaconf import DictConfig from pathlib import Path from typing import List, Optional, Set -import json from ms_agent import LLMAgent from ms_agent.agent import CodeAgent from ms_agent.llm import Message @@ -19,7 +20,6 @@ DEFAULT_TAG) from ms_agent.utils.parser_utils import ImportInfo, parse_imports from ms_agent.utils.utils import extract_code_blocks, file_lock -from omegaconf import DictConfig logger = get_logger() @@ -163,9 +163,7 @@ def find_all_read_files(): for message in messages: if message.tool_calls: for tool_call in message.tool_calls: - if 'read_file' in tool_call[ - 'tool_name'] or 'read_abbreviation_file' in tool_call[ - 'tool_name']: + if 'read_file' in tool_call['tool_name']: arguments = tool_call['arguments'] if isinstance(arguments, str): try: diff --git a/projects/code_genesis/workflow/file_design.py b/projects/code_genesis/workflow/file_design.py index b33f977a5..216a0aa18 100644 --- a/projects/code_genesis/workflow/file_design.py +++ b/projects/code_genesis/workflow/file_design.py @@ -1,7 +1,7 @@ +import json import os from typing import List -import json from ms_agent import LLMAgent from ms_agent.llm import Message diff --git a/projects/code_genesis/workflow/file_order.py b/projects/code_genesis/workflow/file_order.py index 592a68025..0a2ec17eb 100644 --- a/projects/code_genesis/workflow/file_order.py +++ b/projects/code_genesis/workflow/file_order.py @@ -1,7 +1,7 @@ +import json import os from typing import List -import json from ms_agent import LLMAgent from ms_agent.llm import Message diff --git a/projects/code_genesis/workflow/refine.py b/projects/code_genesis/workflow/refine.py index 17c9ffba4..6e07e3359 100644 --- a/projects/code_genesis/workflow/refine.py +++ b/projects/code_genesis/workflow/refine.py @@ -1,15 +1,15 @@ +import json import os import sys +from coding import CodingAgent +from omegaconf import DictConfig from typing import List, OrderedDict -import json -from coding import CodingAgent from ms_agent import LLMAgent from ms_agent.llm import Message from ms_agent.memory.condenser.refine_condenser import RefineCondenser from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/.env.example b/projects/deep_research/.env.example index 59ec0d03b..8f0552b68 100644 --- a/projects/deep_research/.env.example +++ b/projects/deep_research/.env.example @@ -1,5 +1,7 @@ EXA_API_KEY=xxx SERPAPI_API_KEY=xxx +# Tavily Search / Extract (optional; set when using engines: [tavily] or fetcher: tavily_extract) +TAVILY_API_KEY=tvly-xxx OPENAI_API_KEY=xxx OPENAI_BASE_URL=https://your-openai-compatible-endpoint/v1 diff --git a/projects/deep_research/v2/callbacks/quality_checker.py b/projects/deep_research/v2/callbacks/quality_checker.py index 36fadf902..bce1804fa 100644 --- a/projects/deep_research/v2/callbacks/quality_checker.py +++ b/projects/deep_research/v2/callbacks/quality_checker.py @@ -1,12 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json from abc import ABC, abstractmethod +from omegaconf import DictConfig, OmegaConf from typing import List, Optional -import json from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig, OmegaConf logger = get_logger() diff --git a/projects/deep_research/v2/callbacks/reporter_callback.py b/projects/deep_research/v2/callbacks/reporter_callback.py index 7bfb5bee3..87d4f58ea 100644 --- a/projects/deep_research/v2/callbacks/reporter_callback.py +++ b/projects/deep_research/v2/callbacks/reporter_callback.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # yapf: disable +import json import os import re import shutil -from typing import Any, Dict, List, Optional, Set - -import json from callbacks.quality_checker import (ReportQualityChecker, build_quality_checkers) +from omegaconf import DictConfig +from typing import Any, Dict, List, Optional, Set + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_MEMORY_DIR -from omegaconf import DictConfig logger = get_logger() @@ -176,9 +176,9 @@ class ReporterCallback(Callback): '- 请严格遵守系统指令中的要求,不要遗漏、忽略任何合理的规则。\n' '- 审查要点包括事实准确性、逻辑一致性、用户核心问题的覆盖度、引用与论据的对齐关系、引用格式问题、内容完整性等等。' '修改须有明确依据(如事实冗余、逻辑混乱、证据不一致、格式出错等),不要为了"润色"而改动结构/质量良好的内容。\n' - '- 读取报告内容一次后形成判断,后续核查优先使用 search_file_content 或带 start_line / end_line 的 read_file,不要反复全量读取同一文件。' + '- 读取报告内容一次后形成判断,后续核查优先使用 read_file 的 offset/limit 或 abbreviate,不要反复全量读取同一文件。' '在读取文件前先检查对话历史中是否已包含该文件的内容,避免重复读取。\n' - '- 优先使用定点修改(search_file_content -> replace_file_contents / replace_file_lines),仅在必要时才读取全文。' + '- 优先使用定点修改(read_file 定位片段 -> edit_file 精确替换),仅在必要时才读取全文。' '仅在定点修改完全无法解决时使用 write_file,且**必须完整保留所有有价值的内容**,严禁使用占位符、省略标记、引用其他内容等方式替代正文。\n' '- 质量较高无需修改的部分直接跳过。如果[Reporter 工作总结]中无异常且审查确认全文质量良好,直接进入结论阶段即可。\n\n' '**需避免的常见错误:**\n' @@ -210,11 +210,11 @@ class ReporterCallback(Callback): 'Edits must have clear justification (e.g., factual redundancy, logical confusion, evidence inconsistency, ' 'formatting errors, etc.) — do not alter well-structured, high-quality content merely for "polishing."\n' '- Read the report content ONCE to form your assessment. For subsequent ' - 'checks, prefer `search_file_content` or `read_file` with `start_line`/`end_line`. ' + 'checks, prefer `read_file` with `offset`/`limit` or `abbreviate`. ' 'Do not re-read the entire file repeatedly. Check your conversation history before ' 'reading any file to avoid redundant reads.\n' - '- Prefer targeted fixes (`search_file_content` -> `replace_file_contents` / ' - '`replace_file_lines`); only read the full text when necessary. ' + '- Prefer targeted fixes (`read_file` to locate a span, then `edit_file` with exact ' + '`old_string`/`new_string`); only read the full text when necessary. ' 'Use `write_file` only when targeted fixes are completely insufficient, ' 'and you **must preserve ALL valuable content in full** — never use placeholders, ' 'ellipsis markers, or references to other content as substitutes for actual text.\n' @@ -231,8 +231,8 @@ class ReporterCallback(Callback): } _WORK_SUMMARY_LABEL = { - 'zh': '**[Reporter 工作总结]**', - 'en': '**[Reporter Work Summary]**', + 'zh': '**[Reporter 返回的 JSON 结果]**', + 'en': '**[Reporter\'s Returned JSON Result]**', } def __init__(self, config: DictConfig): diff --git a/projects/deep_research/v2/callbacks/researcher_callback.py b/projects/deep_research/v2/callbacks/researcher_callback.py index 4796e2151..1b162f83e 100644 --- a/projects/deep_research/v2/callbacks/researcher_callback.py +++ b/projects/deep_research/v2/callbacks/researcher_callback.py @@ -1,14 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +# yapf: disable import os -from typing import List, Optional - +import re +import shutil from callbacks.quality_checker import (ReportQualityChecker, build_quality_checkers) +from omegaconf import DictConfig, OmegaConf +from typing import List, Optional + from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback +from ms_agent.llm.openai_llm import OpenAI as OpenAILLM from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -20,14 +24,19 @@ class ResearcherCallback(Callback): a chain of quality checks before allowing the run to end: 1. **File existence**: has ``final_report.md`` been written to disk? - 2. **Quality checkers**: a configurable list of + 2. **Compression check**: is the report over-compressed vs + ``reports/draft.md``? + 3. **Quality checkers**: a configurable list of :class:`ReportQualityChecker` instances run in order; the first failure triggers a reflection prompt. - If any check fails, a reflection prompt is injected as a ``user`` - message, ``runtime.should_stop`` is flipped back to ``False``, and - the agent continues for one more iteration. A configurable retry - cap prevents infinite loops. + At task end (``on_task_end``), optionally: + + - Selects the best available report source (``final_report.md`` > + ``reports/report.md`` > ``reports/draft.md``) based on character-count + retention ratio, and promotes it to ``final_report.md``. + - Runs a format-cleanup agent to fix citation and reference formatting + issues (e.g., multiple reference sections, inconsistent numbering). YAML configuration (all optional, shown with defaults):: @@ -35,6 +44,17 @@ class ResearcherCallback(Callback): enabled: true max_retries: 2 report_filename: final_report.md + compression_check: + enabled: false + min_retention_ratio: 0.3 + report_selection: + enabled: false + min_retention_ratio: 0.3 + report_cleanup: + enabled: false + # model: ... # defaults to researcher llm.model + # openai_api_key: ... # falls back to llm.openai_api_key + # openai_base_url: ... # falls back to llm.openai_base_url quality_check: enabled: true model: qwen3.5-flash # lightweight audit model @@ -42,19 +62,35 @@ class ResearcherCallback(Callback): # openai_base_url: ... # falls back to llm.openai_base_url """ + REPORTS_DIR = 'reports' + DRAFT_FILENAME = 'draft.md' + REPORT_FILENAME = 'report.md' + DEFAULT_MIN_RETENTION_RATIO = 0.3 + _CLEANUP_OUTPUT_MIN_RATIO = 0.75 + _MAX_CLEANUP_CHARS = 200000 + _REFLECTION_TEMPLATES = { 'zh': { 'no_report': ('外部检查发现:输出目录中尚未生成 {filename},该文件原本应由 Reporter 子代理自动创建。\n' '请确认最终报告未交付的原因,并立即采取行动修复。\n' '请注意:不要使用占位符或缩略内容替代实际报告正文。'), + 'over_compressed': + ('外部检查发现:{report_name} 的内容量({report_chars} 字符)' + '仅为 {draft_name}({draft_chars} 字符)的 {ratio:.0%},有可能存在内容丢失风险,' + '请对报告内容进行检查并采取合理的行动。\n' + '**重要提醒**:{draft_name} 是由工具逐章组装的完整版本,理论上保留了最大的证据保真度。\n' + '- 如果你确认你对报告进行的修改是合理的,可以直接说明压缩内容的理由,无需再次修改或者重写。\n' + '- 如果你发现 {report_name} 相比 {draft_name} 确实存在不合理的压缩,' + '请通过重写/追加/续写等方式来修复这些问题。\n' + '请立即采取行动完成报告交付。'), 'low_quality': ('外部检查发现:{filename} 的内容存在质量问题——{reason}。\n' '请仔细确认上述质量问题是否属实、是否还有更多问题,并立即采取行动修复。\n' '**重要提醒**:如果质量问题属实,你必须按照以下原则进行修复:\n' - '1. 优先通过有针对性的局部修改完成修复。请使用 file_system---search_file_content 定位问题段落,' - '然后使用 file_system---replace_file_contents 和 file_system---replace_file_lines 进行针对性修复。' - '需要时可以使用 file_system---read_file (with start_line/end_line) 验证上下文是否一致。\n' + '1. 优先通过有针对性的局部修改完成修复。请使用 file_system---read_file(可用 offset/limit 或 abbreviate)定位问题段落,' + '然后使用 file_system---edit_file(old_string/new_string 必须与原文完全一致)进行针对性修复。' + '需要时可以使用 file_system---read_file(with offset/limit)验证上下文是否一致。\n' '2. 如果确认无法通过1完成修复,可以使用 file_system---write_file 全量重写报告,但请注意以下可能的质量违规:\n' '- 用省略号或缩略标记替代正文,如"(同之前,略)"、"此处省略"、"篇幅所限不再展开"、' '"……以下类似"、"内容已截断"、"Content truncated for brevity"等;\n' @@ -69,14 +105,25 @@ class ResearcherCallback(Callback): 'Please identify why the final report was not delivered and immediately take action to fix it.\n' 'Note: Do not use placeholders or abbreviated content in place of the actual report body.' ), + 'over_compressed': + ('External inspection found that {report_name} ({report_chars} chars) ' + 'is only {ratio:.0%} of {draft_name} ({draft_chars} chars), ' + 'indicating a risk of content loss. Please review the report content and take appropriate action.\n' + '**IMPORTANT**: {draft_name} is the tool-assembled complete version that theoretically ' + 'preserves maximum evidence fidelity.\n' + '- If you confirm that your modifications to the report are reasonable, you may simply ' + 'explain the rationale for the compression without further modifications or rewrites.\n' + '- If you find that {report_name} has indeed been unreasonably compressed ' + 'compared to {draft_name}, please rewrite/append/continue writing to repair these issues.\n' + 'Please take immediate action to complete report delivery.'), 'low_quality': ('External inspection found quality issues in {filename} — {reason}.\n' 'Please carefully verify whether these issues are valid and whether additional problems exist, ' 'then immediately take action to fix them.\n' '**IMPORTANT**: If the quality issues are confirmed, you must follow these principles to fix them:\n' - '1. PREFER targeted, localized fixes. Use file_system---search_file_content to locate the problematic sections, ' - 'then use file_system---replace_file_contents and file_system---replace_file_lines to apply precise corrections. ' - 'use file_system---read_file (with start_line/end_line) to verify surrounding context when needed.\n' + '1. PREFER targeted, localized fixes. Use file_system---read_file (with offset/limit or abbreviate) to locate the problematic sections, ' + 'then use file_system---edit_file (exact old_string/new_string) to apply precise corrections. ' + 'Use file_system---read_file (with offset/limit) to verify surrounding context when needed.\n' '2. If you confirm that targeted fixes alone cannot resolve the issues, you may use file_system---write_file ' 'to fully rewrite the report, but beware of the following quality violations:\n' '- Replacing body text with ellipsis or brevity markers, e.g., "(same as before, omitted)", ' @@ -89,6 +136,56 @@ class ResearcherCallback(Callback): }, } + _CLEANUP_SYSTEM_PROMPTS = { + 'zh': + ('你是一个研究报告格式清理专家。你的唯一任务是修复报告中的引用和参考文献格式问题。' + '你绝对不能修改报告的实质内容、论点、证据或分析。\n\n' + '只修复以下类型的问题:\n' + '1. 多个参考文献章节:将所有分散的参考文献列表合并为报告末尾的唯一一个统一参考文献章节。' + '移除完全重复的参考文献条目,按首次出现顺序重新编号。\n' + '2. 引用标记不一致:确保正文中所有引用标记使用统一格式(如 [1], [2]),修复格式错误的引用标记。\n' + '3. 失效引用:修复引用了不存在条目的标记,或处理从未被引用的参考文献条目。\n' + '4. 参考文献编号:确保参考文献按照在正文中首次出现的顺序从 [1] 开始连续编号。\n' + '5. 参考文献格式:确保每条参考文献遵循一致的格式。\n\n' + '关键规则:\n' + '- 不得修改、增加、删除或改写任何实质性内容。\n' + '- 不得改变标题、章节结构或组织方式(合并参考文献章节除外)。\n' + '- 不得添加新的引用或删除已引用的内容。\n' + '- 如果报告没有格式问题,则原样返回。\n' + '- 返回修复后的完整报告,而不仅仅是修改的部分。\n' + '- 不要使用 markdown 代码块包裹输出。'), + 'en': + ('You are a research report formatting specialist. Your ONLY job is to fix citation ' + 'and reference formatting issues in the report. You must NOT modify the substantive ' + 'content, arguments, evidence, or analysis in any way.\n\n' + 'Fix ONLY the following types of issues:\n' + '1. Multiple reference sections: Merge all scattered reference/bibliography lists into ' + 'a single unified reference section at the very end of the report. Remove exact duplicate ' + 'entries and renumber sequentially by order of first appearance.\n' + '2. Citation marker inconsistencies: Ensure all in-text citation markers use a consistent ' + 'format (e.g., [1], [2]) throughout the report. Fix any malformed citation markers.\n' + '3. Orphaned citations: Fix citation markers that reference non-existent entries, or handle ' + 'reference entries that are never cited in the text.\n' + '4. Reference numbering: Ensure references are numbered sequentially starting from [1], ' + 'in order of first appearance in the text.\n' + '5. Reference entry formatting: Ensure each reference entry follows a consistent format.\n\n' + 'CRITICAL RULES:\n' + '- Do NOT modify, add, remove, or rephrase any substantive content.\n' + '- Do NOT change headings, section structure, or organization (except merging reference sections).\n' + '- Do NOT add new citations or remove existing cited content.\n' + '- If the report has no formatting issues, return it unchanged.\n' + '- Return the COMPLETE report with fixes applied, not just the changed parts.\n' + '- Do NOT wrap the output in markdown code blocks.'), + } + + _CLEANUP_USER_TEMPLATES = { + 'zh': ('请修复以下研究报告中的引用和参考文献格式问题:\n\n' + '---报告开始---\n{report}\n---报告结束---'), + 'en': ('Please fix the citation and reference formatting issues ' + 'in the following research report:\n\n' + '---BEGIN REPORT---\n{report}\n---END REPORT---'), + } + def __init__(self, config: DictConfig): super().__init__(config) self.output_dir: str = getattr(config, 'output_dir', './output') @@ -105,6 +202,61 @@ def __init__(self, config: DictConfig): self.report_filename = str( getattr(refl_cfg, 'report_filename', self.report_filename)) + # --- Compression check config --- + comp_cfg = ( + getattr(refl_cfg, 'compression_check', None) if refl_cfg else None) + self.compression_check_enabled: bool = False + self.min_retention_ratio: float = self.DEFAULT_MIN_RETENTION_RATIO + if comp_cfg is not None: + self.compression_check_enabled = bool( + getattr(comp_cfg, 'enabled', False)) + self.min_retention_ratio = float( + getattr(comp_cfg, 'min_retention_ratio', + self.DEFAULT_MIN_RETENTION_RATIO)) + + # --- Report selection config (on_task_end) --- + sel_cfg = ( + getattr(refl_cfg, 'report_selection', None) if refl_cfg else None) + self.report_selection_enabled: bool = False + self._selection_min_ratio: float = self.DEFAULT_MIN_RETENTION_RATIO + if sel_cfg is not None: + self.report_selection_enabled = bool( + getattr(sel_cfg, 'enabled', False)) + self._selection_min_ratio = float( + getattr(sel_cfg, 'min_retention_ratio', + self.DEFAULT_MIN_RETENTION_RATIO)) + + # --- Format cleanup agent config (on_task_end) --- + cleanup_cfg = ( + getattr(refl_cfg, 'report_cleanup', None) if refl_cfg else None) + self.report_cleanup_enabled: bool = False + self._cleanup_model: Optional[str] = None + self._cleanup_api_key: Optional[str] = None + self._cleanup_base_url: Optional[str] = None + self._cleanup_generation_config: Optional[dict] = None + if cleanup_cfg is not None: + self.report_cleanup_enabled = bool( + getattr(cleanup_cfg, 'enabled', False)) + self._cleanup_model = getattr(cleanup_cfg, 'model', None) + self._cleanup_api_key = getattr(cleanup_cfg, 'openai_api_key', + None) + self._cleanup_base_url = getattr(cleanup_cfg, 'openai_base_url', + None) + gen_cfg = getattr(cleanup_cfg, 'generation_config', None) + if gen_cfg is not None: + self._cleanup_generation_config = ( + OmegaConf.to_container(gen_cfg, resolve=True) + if isinstance(gen_cfg, DictConfig) else dict(gen_cfg)) + + # --- Derived paths --- + self._reports_dir: str = self.REPORTS_DIR + self._draft_path: str = os.path.join(self.output_dir, + self._reports_dir, + self.DRAFT_FILENAME) + self._inner_report_path: str = os.path.join(self.output_dir, + self._reports_dir, + self.REPORT_FILENAME) + self._retries_used: int = 0 self._checkers: List[ReportQualityChecker] = build_quality_checkers( config) @@ -135,7 +287,166 @@ def _get_template(self, key: str) -> str: def _marker_path(self) -> str: return os.path.join(self.output_dir, self.TASK_FINISHED_MARKER) + def _select_best_report(self) -> Optional[str]: + """Return the path to the best available report, based on char-count + retention ratio against ``reports/draft.md``. + + Candidates (in preference order): + 1. ``final_report.md`` + 2. ``reports/report.md`` + 3. ``reports/draft.md`` + + A candidate is accepted if it has >= ``_selection_min_ratio`` chars + relative to the draft. Falls back to ``draft.md`` if all others + are over-compressed. + """ + final_path = self._report_path + candidates = [ + (final_path, self.report_filename), + (self._inner_report_path, + os.path.join(self._reports_dir, self.REPORT_FILENAME)), + ] + + has_draft = os.path.isfile(self._draft_path) + if not has_draft: + for path, _ in candidates: + if os.path.isfile(path): + return path + return None + + try: + draft_chars = self._read_char_count(self._draft_path) + except OSError: + draft_chars = 0 + + if draft_chars <= 0: + for path, _ in candidates: + if os.path.isfile(path): + return path + return self._draft_path + + for path, name in candidates: + if not os.path.isfile(path): + continue + try: + chars = self._read_char_count(path) + except OSError: + continue + ratio = chars / draft_chars + if ratio >= self._selection_min_ratio: + return path + logger.warning( + f'ResearcherCallback: {name} ({chars} chars) is only ' + f'{ratio:.0%} of draft ({draft_chars} chars), ' + f'trying next candidate.') + + logger.warning('ResearcherCallback: all candidates over-compressed, ' + 'falling back to draft.md.') + return self._draft_path + + def _run_format_cleanup(self, report_path: str) -> bool: + """Run format cleanup on the report to fix citation/reference issues. + + Returns True if cleanup was applied successfully. + """ + try: + with open(report_path, 'r', encoding='utf-8') as f: + content = f.read() + except Exception as exc: + logger.warning(f'ResearcherCallback: failed to read report for ' + f'format cleanup: {exc}') + return False + + if not content.strip(): + logger.info( + 'ResearcherCallback: report is empty, skipping cleanup.') + return False + + if len(content) > self._MAX_CLEANUP_CHARS: + logger.warning( + f'ResearcherCallback: report too long for format cleanup ' + f'({len(content)} chars > {self._MAX_CLEANUP_CHARS}), ' + f'skipping.') + return False + + model = ( + self._cleanup_model or getattr(self.config.llm, 'model', None)) + api_key = ( + self._cleanup_api_key + or getattr(self.config.llm, 'openai_api_key', None)) + base_url = ( + self._cleanup_base_url + or getattr(self.config.llm, 'openai_base_url', None)) + + if not model: + logger.warning( + 'ResearcherCallback: no model configured for format cleanup.') + return False + + gen_cfg = self._cleanup_generation_config or {} + llm_config = OmegaConf.create({ + 'llm': { + 'model': model, + 'openai_api_key': api_key, + 'openai_base_url': base_url, + }, + 'generation_config': gen_cfg, + }) + + try: + client = OpenAILLM(llm_config) + except Exception as exc: + logger.warning(f'ResearcherCallback: failed to create LLM client ' + f'for format cleanup: {exc}') + return False + + sys_prompt = self._CLEANUP_SYSTEM_PROMPTS.get( + self.lang, self._CLEANUP_SYSTEM_PROMPTS['en']) + usr_template = self._CLEANUP_USER_TEMPLATES.get( + self.lang, self._CLEANUP_USER_TEMPLATES['en']) + + try: + response = client.generate(messages=[ + Message(role='system', content=sys_prompt), + Message( + role='user', content=usr_template.format(report=content)), + ]) + cleaned = (response.content or '').strip() + except Exception as exc: + logger.warning( + f'ResearcherCallback: format cleanup LLM call failed: {exc}') + return False + + if not cleaned: + logger.warning( + 'ResearcherCallback: format cleanup returned empty output.') + return False + + # Strip markdown code-block wrapper if present + cleaned = re.sub(r'^```\w*\n', '', cleaned) + cleaned = re.sub(r'\n```\s*$', '', cleaned) + + # Guard against truncated output + if len(cleaned) < len(content) * self._CLEANUP_OUTPUT_MIN_RATIO: + logger.warning( + f'ResearcherCallback: format cleanup output appears ' + f'truncated ({len(cleaned)} vs {len(content)} chars), ' + f'keeping original.') + return False + + try: + with open(report_path, 'w', encoding='utf-8') as f: + f.write(cleaned) + logger.info(f'ResearcherCallback: format cleanup applied ' + f'({len(content)} -> {len(cleaned)} chars).') + return True + except Exception as exc: + logger.warning( + f'ResearcherCallback: failed to write cleaned report: {exc}') + return False + async def on_task_end(self, runtime: Runtime, messages: List[Message]): + # --- Step 1: Write task-finished marker --- try: os.makedirs(self.output_dir, exist_ok=True) with open(self._marker_path, 'w') as f: @@ -147,6 +458,47 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): logger.warning( f'ResearcherCallback: failed to write marker: {exc}') + # --- Step 2: Best report selection --- + if self.report_selection_enabled: + best_source = self._select_best_report() + if best_source and best_source != self._report_path: + try: + os.makedirs( + os.path.dirname(self._report_path), exist_ok=True) + shutil.copy2(best_source, self._report_path) + source_name = os.path.relpath(best_source, self.output_dir) + logger.info( + f'ResearcherCallback: promoted {source_name} -> ' + f'{self.report_filename}') + except Exception as exc: + logger.warning( + f'ResearcherCallback: failed to promote report: ' + f'{exc}') + elif best_source: + logger.info( + f'ResearcherCallback: {self.report_filename} is already ' + f'the best candidate, no promotion needed.') + else: + logger.warning('ResearcherCallback: no report file found for ' + 'best-report selection.') + + # --- Step 3: Format cleanup agent --- + if self.report_cleanup_enabled: + if os.path.isfile(self._report_path): + logger.info( + 'ResearcherCallback: running format cleanup agent on ' + f'{self.report_filename}...') + self._run_format_cleanup(self._report_path) + else: + logger.warning( + f'ResearcherCallback: {self.report_filename} not found, ' + f'skipping format cleanup.') + + @staticmethod + def _read_char_count(path: str) -> int: + with open(path, 'r', encoding='utf-8') as f: + return len(f.read()) + async def after_tool_call(self, runtime: Runtime, messages: List[Message]): if not self.enabled: return @@ -169,7 +521,36 @@ async def after_tool_call(self, runtime: Runtime, messages: List[Message]): self._retries_used += 1 return - # --- Check 2: quality checker chain --- + # --- Check 2: compression check vs reports/draft.md --- + if self.compression_check_enabled and os.path.isfile(self._draft_path): + try: + report_chars = self._read_char_count(self._report_path) + draft_chars = self._read_char_count(self._draft_path) + if draft_chars > 0: + ratio = report_chars / draft_chars + if ratio < self.min_retention_ratio: + draft_rel = os.path.join(self._reports_dir, + self.DRAFT_FILENAME) + logger.warning( + f'ResearcherCallback: {self.report_filename} ' + f'({report_chars} chars) is only {ratio:.0%} of ' + f'{draft_rel} ({draft_chars} chars), ' + 'injecting over-compression prompt.') + prompt = self._get_template('over_compressed').format( + report_name=self.report_filename, + report_chars=report_chars, + draft_name=draft_rel, + draft_chars=draft_chars, + ratio=ratio) + messages.append(Message(role='user', content=prompt)) + runtime.should_stop = False + self._retries_used += 1 + return + except OSError as exc: + logger.warning(f'ResearcherCallback: failed to read files for ' + f'compression check: {exc}') + + # --- Check 3: quality checker chain --- if not self._checkers: logger.info('ResearcherCallback: no quality checkers configured, ' 'skipping quality gate.') diff --git a/projects/deep_research/v2/callbacks/searcher_callback.py b/projects/deep_research/v2/callbacks/searcher_callback.py index e48d35880..b93a8fd38 100644 --- a/projects/deep_research/v2/callbacks/searcher_callback.py +++ b/projects/deep_research/v2/callbacks/searcher_callback.py @@ -1,19 +1,56 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import uuid +from omegaconf import DictConfig from typing import Any, List, Optional -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() +def _parse_search_result_json(text: str) -> Optional[Any]: + """Parse searcher final JSON from raw assistant text. + + Accepts plain JSON, fenced ```json blocks, or a JSON object embedded in + surrounding prose (first balanced object via :func:`json.JSONDecoder.raw_decode`). + """ + if text is None: + return None + if not isinstance(text, str): + return None + text = text.strip() + if not text: + return None + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + pass + m = re.search( + r'```(?:json)?\s*\r?\n(.*?)```', text, flags=re.DOTALL | re.IGNORECASE) + if m: + block = m.group(1).strip() + if block: + try: + return json.loads(block) + except (json.JSONDecodeError, TypeError): + pass + dec = json.JSONDecoder() + for i, ch in enumerate(text): + if ch not in '{[': + continue + try: + return dec.raw_decode(text[i:])[0] + except json.JSONDecodeError: + continue + return None + + class SearcherCallback(Callback): """ Callback for Searcher agent. @@ -221,33 +258,48 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): try: # Prefer JSON file if possible; fallback to markdown otherwise. - if isinstance(content, str): - parsed_json = json.loads(content) - else: + if isinstance(content, (dict, list)): parsed_json = content - try: - with open(json_path, 'x', encoding='utf-8') as f: - json.dump( - parsed_json, f, ensure_ascii=False, indent=2) - logger.info( - f'Searcher: Search result saved to {json_path}') - except FileExistsError: - logger.info( - f'Search result already exists at {json_path}') - except (json.JSONDecodeError, TypeError): - logger.warning( - 'Failed to parse search result as JSON, saving as markdown' - ) - text = content if isinstance(content, - str) else str(content) - try: - with open(md_path, 'x', encoding='utf-8') as f: - f.write(text) - logger.info( - f'Searcher: Search result saved to {md_path}') - except FileExistsError: - logger.info( - f'Search result already exists at {md_path}') + elif isinstance(content, str): + try: + parsed_json = json.loads(content) + except (json.JSONDecodeError, TypeError): + parsed_json = _parse_search_result_json(content) + if parsed_json is not None: + logger.info( + 'Searcher: parsed JSON from fenced or embedded payload' + ) + else: + parsed_json = _parse_search_result_json(str(content)) + + if parsed_json is not None: + try: + with open(json_path, 'x', encoding='utf-8') as f: + json.dump( + parsed_json, + f, + ensure_ascii=False, + indent=2) + logger.info( + f'Searcher: Search result saved to {json_path}' + ) + except FileExistsError: + logger.info( + f'Search result already exists at {json_path}') + else: + logger.warning( + 'Failed to parse search result as JSON, saving as markdown' + ) + text = content if isinstance(content, + str) else str(content) + try: + with open(md_path, 'x', encoding='utf-8') as f: + f.write(text) + logger.info( + f'Searcher: Search result saved to {md_path}') + except FileExistsError: + logger.info( + f'Search result already exists at {md_path}') except Exception as e: logger.warning( f'Unexpected error when saving search result: {e}') diff --git a/projects/deep_research/v2/eval/dr_bench_runner.py b/projects/deep_research/v2/eval/dr_bench_runner.py index 1917564bf..add3634ec 100644 --- a/projects/deep_research/v2/eval/dr_bench_runner.py +++ b/projects/deep_research/v2/eval/dr_bench_runner.py @@ -17,7 +17,9 @@ """ from __future__ import annotations + import argparse +import json import os import subprocess import sys @@ -28,8 +30,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple -import json - try: # Auto-load environment variables from a nearby `.env` (if present). from dotenv import find_dotenv, load_dotenv diff --git a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt index fc276252a..19faef800 100644 --- a/projects/deep_research/v2/prompts/reporter/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/reporter/en/gpt5.txt @@ -55,7 +55,7 @@ Stopping conditions (stop if any one is satisfied): - Based on the reflection results and recorded conflicts, you MUST rewrite the final markdown report content and save it to reports/report.md. - The final markdown report must preserve the information density and structural quality of the draft — never replace substantive content with ellipsis/brevity markers (e.g., "omitted here", "content truncated"), pointers to external files (e.g., "details are in chapter_2.md"), or hollow reference-only placeholders (e.g., "see [1]"). - The format, style, and other aspects of the report must **follow the specifications required by the user's input and the "Default Report Style" section**. The report must include citations to reference sources. - - **Writing strategy to minimize information loss**: Prioritize writing the full report in a single file_system---write_file call. Switch to an incremental strategy if your write attempt gets truncated by the output limit: initialize the file with file_system---write_file containing as much content as possible; then sequentially append the remainder using file_system---replace_file_lines with start_line=-1 until the report is complete, utilizing as few calls as possible. + - **Writing strategy to minimize information loss**: Prioritize writing the full report in a single file_system---write_file call. Switch to an incremental strategy if your write attempt gets truncated by the output limit: initialize the file with file_system---write_file containing as much content as possible; then read the end of reports/report.md and repeatedly use file_system---edit_file to replace a unique trailing anchor with anchor+next_chunk (append by editing the tail), re-reading the tail before each append until the report is complete, using as few calls as practical. - After delivering the final report, return a work summary in JSON format in the conversation: - The Execution_Summary field must include the report generation status, evidence coverage, a summary of conflicts, and any other information that should be communicated to the user. - The Artifacts field must include the paths to intermediate file artifacts. diff --git a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt index 8891d8313..c79cd7344 100644 --- a/projects/deep_research/v2/prompts/researcher/en/gpt5.txt +++ b/projects/deep_research/v2/prompts/researcher/en/gpt5.txt @@ -15,7 +15,7 @@ Action protocol: Before outputting the final result, every iteration MUST invoke - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. - Draft, review, deliver: - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. The Reporter will automatically deliver the complete report as final_report.md. - - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter’s output preserves maximum evidence fidelity. + - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---read_file (with `offset`/`limit` or `abbreviate` as needed) to locate the passage, then file_system---edit_file with an exact `old_string`/`new_string` pair to fix it. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter’s output preserves maximum evidence fidelity. # Reference Workflow The following is a proven workflow that works well for most research tasks. @@ -50,19 +50,21 @@ Stopping conditions (stop if you are confident to proceed to the next phase): ## Phase 3: Report Generation - Invoke the Reporter sub-agent to generate the report. Provide the Reporter sub-agent with the complete report topic, target audience, background, task description, writing requirements, section constraints, and any other necessary information. - - Note: do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. -- The Reporter will deliver the complete report as final_report.md. After the Reporter returns, you MUST review the report for quality and accuracy: + - Do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. + - The Reporter sub-agent writes the final report to reports/report.md and returns only the execution summary and artifact file paths. The system automatically copies reports/report.md to final_report.md. Therefore, final_report.md may not be listed in the Reporter's Artifacts field. Do not ask the Reporter to create final_report.md directly. Review and edit final_report.md. +- After the Reporter returns, you MUST review the report for quality and accuracy: + - **Read once, then act.** Read the full report content once to form your assessment. For subsequent checks, use file_system---search_file_content or file_system---read_file with start_line/end_line — do NOT re-read the entire report file repeatedly. Always check your conversation history before re-reading a file that may already be in context. - **Verify first.** Before editing, spot-check factual accuracy, logical consistency, coverage of the user's core questions, and citation–claim alignment against the collected evidence. - The report MUST comply with the "Quality Constraints" and "Default Report Style" sections. Execution Summary (执行摘要) MUST NOT appear as a chapter in the report body. - **Edit with justification.** Every substantive change (compression, deletion, restructuring, format conversion) must be driven by a concrete problem — such as factual redundancy, logical disorganization, evidence inconsistency, or style/quality violations. Well-structured content with reasonable depth and detail must be preserved as-is, including its structure, granularity, and length. - **Do not over-edit.** Do not convert flowing paragraphs into bullet-point lists, flatten detailed subsections into one-line summaries, or replace evidence-backed analysis with high-level abstractions — unless the original format genuinely hinders readability or violates the report style. - If the report passes your review without issues: proceed directly to your conclusion. Do NOT rewrite it "for polish." - If issues are found, **strongly prefer targeted corrections** over full rewrites: - - **Standard workflow**: use file_system---search_file_content to locate the problem, then file_system---replace_file_contents to fix it. This is the safest and most precise approach. - - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail. - - Parallel editing: for multiple independent fixes in the same file, use file_system---search_file_content and file_system---replace_file_contents in parallel when `source` spans do not overlap. However, NEVER call file_system---replace_file_lines in parallel on the same file — line numbers shift after each call. - - **Deleting or replacing line ranges**: use file_system---replace_file_lines with start_line/end_line to delete or replace a block of lines (e.g., removing an entire section). Use file_system---search_file_content first to locate the line numbers (start line and end line). - - **Inspect before editing**: use file_system---read_file (with start_line/end_line) to verify surrounding context when needed. + - **Standard workflow**: use file_system---read_file to load the relevant region (use `offset`/`limit` for large files, or `abbreviate` for a quick structural pass), then file_system---edit_file with an exact `old_string`/`new_string` match. You must read a file before writing to it; re-read if the file may have changed. + - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail — copy the span verbatim from read_file output. + - Apply multiple edits to the same file **sequentially** (re-read between edits if needed) so anchors stay valid; do not parallelize conflicting edits on one file. + - **Deleting or replacing a block**: read the section, then file_system---edit_file with `old_string` set to the exact contiguous block to remove or replace (include enough surrounding lines to make `old_string` unique). + - **Inspect before editing**: use file_system---read_file (with `offset`/`limit`) to verify surrounding context when needed. - **Last resort only**: file_system---write_file overwrites the entire file — use it only when targeted tools cannot address the issue (e.g., extensive structural reorganization). You must reproduce ALL content valuable to the user. - WARNING: Full report rewrites may carry high risk of content loss. Do not over-compress the report. Do not replace any content with placeholders such as "Content truncated for brevity." or "This section is stored in xxx file." - Finally show your conclusions for the entire task in the conversation. @@ -77,12 +79,12 @@ Stopping conditions (stop if you are confident to proceed to the next phase): # Tool Invocation Protocol - You MUST use the tools under the todo_list server to create, update, and read the TODO list. You MUST NOT use any other tools or services to maintain the TODO list. - You MUST use the tools under the agent_tools server to invoke the Searcher and Reporter sub-agents. You are not allowed to invoke non-existent sub-agent tools, and you MUST carefully follow the input requirements of those tools. -- When context is unclear (e.g., the Searcher sub-agent’s output appears to have lost details, or the Reporter sub-agent’s report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. +- When context is unclear (e.g., the Searcher sub-agent's output appears to have lost details, or the Reporter sub-agent's report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. - You are encouraged to invoke multiple tools in parallel when tasks are independent (e.g., retrieving unrelated information or performing separate operations). - For file-level operations, keep using relative paths. # Quality Constraints -- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent’s research conclusions and stored evidence. +- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent's research conclusions and stored evidence. - Clearly track time constraints and the current date. If the knowledge you intend to apply may be outdated, do not trust your memory; query via tools instead. - Strictly control scope: if the user asks for X, do not drift to Y. - Citation integrity in the final report: diff --git a/projects/deep_research/v2/prompts/researcher/en/thinking.txt b/projects/deep_research/v2/prompts/researcher/en/thinking.txt new file mode 100644 index 000000000..38c211a87 --- /dev/null +++ b/projects/deep_research/v2/prompts/researcher/en/thinking.txt @@ -0,0 +1,110 @@ +You are a highly capable, thoughtful, and precise research assistant. Your job is to plan and manage the end-to-end deep research workflow, delegate retrieval and writing tasks to sub-agents/tools, synthesize evidence into decisions, and polish the final report before delivery. +You have everything you need to resolve the task. I want you to fully solve this autonomously before coming back to me. +Time reminder: The current date is , and the current time is . +Language reminder: If you can infer the language from the user's query, make sure to keep this in mind when generating the report. +Research iterations: This refers to the number of loops in the research & analysis phase (excluding the report-generation phase). If the user does not specify the maximum number of research iterations, the default maximum is 6 iterations. You MUST complete the task within the maximum number of iterations. +Action protocol: Before outputting the final result, every iteration MUST invoke at least one tool. Periodically assess progress to determine the most effective next step. Keep your conversation output focused on key decisions and conclusions. Make each tool call purposeful and avoid redundancy — every tool result permanently occupies context capacity. + +# Primary Responsibilities +- Plan & orchestrate: + - Determine whether the request starts a new task or continues an unfinished one. If continuing, first assess the current completion status and recover relevant context; then convert the user's request into an executable research plan, store it as a TODO list (plan.json), and perform self-reflection by additionally generating a verification checklist (checklist.yaml). + - Based on task difficulty and user intent, orchestrate the available sub-agents and tools, and control the handling logic for different scenarios (short answer vs. professional report vs. casual conversation; default to a professional report unless the user asks otherwise). +- Retrieve evidence: + - When evidence is insufficient, delegate tasks to the Searcher sub-agent (i.e., agent_tools---searcher_tool) to perform an iterative research loop (when concurrency is allowed, 2–4 sub-agents can be invoked in parallel; prioritize parallel invocation when tasks are parallelizable). +- Analyze & synthesize: + - When the research can only move forward by conducting synthesis based on the collected materials—such as framework design, cross-validation, scenario analysis, data analysis, etc.—you MUST proactively complete these tasks using the available tools. +- Draft, review, deliver: + - When research is sufficient, delegate to the Reporter sub-agent (i.e., agent_tools---reporter_tool) to generate the research report. The Reporter will automatically deliver the complete report as final_report.md. + - Then you MUST review the report for quality and accuracy. If issues are found, apply **targeted corrections** using file_system---search_file_content to locate problems and file_system---replace_file_contents to fix them. Do NOT rewrite the entire report unless you are strongly sure it is necessary — the Reporter's output preserves maximum evidence fidelity. + +# Reference Workflow +The following is a proven workflow that works well for most research tasks. +You are free to adapt, reorder, or skip steps based on the complexity and requirements of the current task — but the general approach has been validated across many scenarios. + +## Phase 1: Task Planning +- Deeply understand the user's intent: analyze the user's conversational goal, background needs, and expected deliverables; proactively infer whether to start from scratch or continue an unfinished task. +- If resuming from an unfinished task, start by checking the current completion status using todo_list---todo_read and other available tools. +- Develop a manageable plan based on user's needs and task progress. Use todo_list---todo_write and file_system---write_file to generate the TODO list and the corresponding verification checklist checklist.yaml, respectively. + - The TODO list must cover all subtasks that need to be completed. It is used to clearly communicate your full plan to the user. You do not need to explicitly state which tools you will use; simply provide the tasks themselves; but you must ensure that every task can be completed using the existing tools. + - Tasks in the TODO list must be explicit, clear, and focused on solving the core problem. Each task should contain no more than three core questions to answer. + - Tasks in the TODO list should be assigned reasonable priorities: high for tasks directly answering the user's core questions, medium for supporting context or secondary dimensions, low for nice-to-have extensions. High-priority tasks should be executed first, while medium- and low-priority tasks should be performed only if the iteration budget allows. +- Do self-reflection. If you find issues with the current TODO list, fix them; otherwise, you may skip this step. + - If necessary, you can invoke the Searcher sub-agent at most once for concept clarification; + - If you find issues in the TODO list, you must revise it via the todo_list---todo_write tool. + +## Phase 2: Research & Analysis +Repeat the following steps until a stopping condition is met: +- Based on the execution status of tasks in the current TODO list, select appropriate actions: + - For tasks that require evidence retrieval, delegate them to the Searcher sub-agent. Make sure to provide detailed and clear task instructions; + - For tasks that require interim syntheses, decisions/trade-offs, frameworks/mappings, uncertainty tracking, justified recommendations, or structural diagrams (preferably Mermaid syntax), use evidence_store---write_analysis to record these intermediate analyses, and include based_on_note_ids when possible; + - For tasks that require data analysis or chart generation, use code_executor---notebook_executor to solve them. Try to finish in as few rounds as possible. When writing code, use relative file paths—the executor's working directory is the output root. Store key computed results via evidence_store---write_analysis. +- After completing the above actions, reflect and update the TODO list: + - Summarize interim findings; explicitly identify the evidence that has already been collected and maintained; identify conflicts and evidence gaps. + - Update the task statuses in the TODO list ('pending'/'in_progress'/'completed'/'cancelled') as soon as their status changes. + - If you identify issues in the plan and decide to revise it, update the TODO list. +Stopping conditions (stop if you are confident to proceed to the next phase): +- All subtasks for the research & analysis phase in the TODO list have been completed; or +- All the core tasks (high-priority tasks) have been completed; or +- The marginal benefit of further searching is very low; or +- The maximum number of research iterations has been reached. + +## Phase 3: Report Generation +- Invoke the Reporter sub-agent to generate the report. Provide the Reporter sub-agent with the complete report topic, target audience, background, task description, writing requirements, section constraints, and any other necessary information. + - Do not impose a word-count requirement on the Reporter sub-agent unless the user explicitly requests it; DO NOT ask the Reporter sub-agent to include the Execution Summary (执行摘要) as a separate section in the report. + - The Reporter sub-agent writes the final report to reports/report.md and returns only the execution summary and artifact file paths. The system automatically copies reports/report.md to final_report.md. Therefore, final_report.md may not be listed in the Reporter's Artifacts field. Do not ask the Reporter to create final_report.md directly. Review and edit final_report.md. +- After the Reporter returns, you MUST review the report for quality and accuracy: + - **Read once, then act.** Read the full report content once to form your assessment. For subsequent checks, use file_system---search_file_content or file_system---read_file with start_line/end_line — do NOT re-read the entire report file repeatedly. Always check your conversation history before re-reading a file that may already be in context. + - **Verify first.** Before editing, spot-check factual accuracy, logical consistency, coverage of the user's core questions, and citation–claim alignment against the collected evidence. + - The report MUST comply with the "Quality Constraints" and "Default Report Style" sections. Execution Summary (执行摘要) MUST NOT appear as a chapter in the report body. + - **Edit with justification.** Every substantive change (compression, deletion, restructuring, format conversion) must be driven by a concrete problem — such as factual redundancy, logical disorganization, evidence inconsistency, or style/quality violations. Well-structured content with reasonable depth and detail must be preserved as-is, including its structure, granularity, and length. + - **Do not over-edit.** Do not convert flowing paragraphs into bullet-point lists, flatten detailed subsections into one-line summaries, or replace evidence-backed analysis with high-level abstractions — unless the original format genuinely hinders readability or violates the report style. +- If the report passes your review without issues: proceed directly to your conclusion. Do NOT rewrite it "for polish." +- If issues are found, **strongly prefer targeted corrections** over full rewrites: + - **Standard workflow**: use file_system---search_file_content to locate the problem, then file_system---replace_file_contents to fix it. This is the safest and most precise approach. + - Precision reminder: Punctuation mismatches (e.g., Chinese `、` vs English `,`; full-width vs half-width characters), whitespace differences, or line-break variations usually cause the replacement to fail. + - Parallel editing: for multiple independent fixes in the same file, use file_system---search_file_content and file_system---replace_file_contents in parallel when `source` spans do not overlap. However, NEVER call file_system---replace_file_lines in parallel on the same file — line numbers shift after each call. + - **Deleting or replacing line ranges**: use file_system---replace_file_lines with start_line/end_line to delete or replace a block of lines (e.g., removing an entire section). Use file_system---search_file_content first to locate the line numbers (start line and end line). + - **Inspect before editing**: use file_system---read_file (with start_line/end_line) to verify surrounding context when needed. + - **Last resort only**: file_system---write_file overwrites the entire file — use it only when targeted tools cannot address the issue (e.g., extensive structural reorganization). You must reproduce ALL content valuable to the user. + - WARNING: Full report rewrites may carry high risk of content loss. Do not over-compress the report. Do not replace any content with placeholders such as "Content truncated for brevity." or "This section is stored in xxx file." +- Finally show your conclusions for the entire task in the conversation. + +# Process Constraints +1. Monitor and update the TODO list throughout the process; DO NOT store plans only in the conversation text; if unexpected issues arise, record the failure, adjust the plan, and continue with a fallback path when possible. +2. Do not conduct extended web research or draft the full report yourself. Delegate all large-scale retrieval and report drafting to sub-agents. +3. When evidence is insufficient or conclusions conflict with each other, you must explicitly acknowledge the uncertainty, reflect proactively, and attempt to resolve it using the available tools (including sub-agents), while keeping your research iterations limit in mind. +4. Follow the stopping conditions defined in Phase 2. +5. Avoid redundant tool calls. For example, after todo_list---todo_write, the tool response includes the updated TODO list, so you don't need to call todo_list---todo_read again. Similarly, after todo_list---todo_read, you don't need to call file_system---read_file to read related files again (plan.json, plan.md). + +# Tool Invocation Protocol +- You MUST use the tools under the todo_list server to create, update, and read the TODO list. You MUST NOT use any other tools or services to maintain the TODO list. +- You MUST use the tools under the agent_tools server to invoke the Searcher and Reporter sub-agents. You are not allowed to invoke non-existent sub-agent tools, and you MUST carefully follow the input requirements of those tools. +- When context is unclear (e.g., the Searcher sub-agent's output appears to have lost details, or the Reporter sub-agent's report has issues), you should read, filter, and load evidence using the evidence_store server, ensuring you have sufficient confidence before proceeding to the next step. +- You are encouraged to invoke multiple tools in parallel when tasks are independent (e.g., retrieving unrelated information or performing separate operations). +- For file-level operations, keep using relative paths. + +# Quality Constraints +- NEVER fabricate citations or sources. Every factual statement in the final deliverable must be supported by the Searcher sub-agent's research conclusions and stored evidence. +- Clearly track time constraints and the current date. If the knowledge you intend to apply may be outdated, do not trust your memory; query via tools instead. +- Strictly control scope: if the user asks for X, do not drift to Y. +- Citation integrity in the final report: + - **Fix if broken**: + - Invalid citation forms (e.g., [Note ID]-style placeholders) — replace with proper `[1]`, `[2]`, ... numbered markers. + - Multiple ## References / ## 参考文献 sections (e.g., per-chapter reference lists) are not allowed — this includes any variant headings such as "## 参考文献(合并版)", "## References (Merged)", "## 参考资料", or similar. The report body and individual chapters must NOT contain any reference/bibliography list; remove such sections entirely. Keep only one unified reference section at the very end of the report. Re-number in-text citations if needed. + - **Supplement if missing**: Add numbered citations `[1]`, `[2]`, ... in the body (multiple may appear together like `[1][3]`). The report must end with exactly one `## References` (English) or `## 参考文献` (Chinese) section with consistent numbering. Do not use long-title links (e.g., `[Title](URL)`) in the body text. + - **Preserve by default**: Do not alter correct citations delivered by the Reporter sub-agent. Your edits must not cause citation loss. +- For the final report, you MUST use the language specified by the user; if none is specified, you must keep it consistent with the language the user is using. +- The final report in final_report.md MUST follow the "Default Report Style" section. + +# Default Report Style +- Technical/research report tone: careful and verifiable; as much information as possible and as faithful to the original evidence as possible; do not over-compress into an executive-summary-only output; avoid overly casual language; ensure readability. +- Clear structure: Default to cohesive paragraphs (not outline-as-bullets; avoid choppy, overly short paragraphs). Use bullet points when genuinely itemized lists improve clarity; avoid nested bullets and heavy indentation. +- Prefer a clean heading hierarchy: `#` for the report title, `##` for top-level chapters (e.g., `## 2. Background and Problem`), `###` and `####` for sub-sections. Do not exceed four heading levels. All headings MUST use Markdown ATX syntax. +- Chapter titles you provide should be concise and natural-sounding. Avoid overly long compound titles with excessive parenthetical clarifications (e.g., avoid "Challenges, Governance and Compliance (Including Governance Framework and Procurement Clauses)"). +- DO NOT include meta-text in the report body, such as target audience descriptions (e.g., "Target Audience: ...", "面向对象:..."), author notes (e.g., "Note: ...", "注:..."), or execution disclaimers. The report should be a polished, self-contained document. + +# Unexpected Handling +1. You may encounter tool invocation failures due to network, security, permission, or other unexpected reasons. You must prioritize ensuring task completion via reasonable retry strategies and error-handling logic. +2. If the user tries to make you perform tasks beyond your capability, you must explicitly state the potential risks and try to combine existing tools and capabilities to propose possible solutions. +3. If the user asks for a concise answer rather than a full report, you may skip Phase 3 and provide the conclusion directly. +4. If the user attempts casual conversation rather than research tasks, you do not need to start the research workflow; you may respond normally and try to guide the user to initiate a research task. diff --git a/projects/deep_research/v2/reporter.exp_nosnap.yaml b/projects/deep_research/v2/reporter.exp_nosnap.yaml new file mode 100644 index 000000000..d1c67ccdf --- /dev/null +++ b/projects/deep_research/v2/reporter.exp_nosnap.yaml @@ -0,0 +1,93 @@ +llm: + service: openai + model: qwen3.5-plus + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + # Enable explicit prefix caching (auto-detects provider from openai_base_url) + force_prefix_cache: true + # Supports role names: system, user, assistant, tool, last_message + prefix_cache_roles: [system, user, assistant, tool] + extra_body: + enable_thinking: false + # show_reasoning: true + # reasoning_output: stdout + + +tag: deep-research + + +prompt: + root: prompts/ + agent: reporter + lang: en + family: gpt5 + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read + evidence_store: + mcp: false + evidence_dir: evidence + include: + - load_index + - get_note + - list_notes + - get_analysis + - list_analyses + report_generator: + mcp: false + reports_dir: reports + plugins: + - tools/evidence_tool.py + - tools/report_tool.py + + +handler: time_handler + +code_file: reporter + +callbacks: + - callbacks/reporter_callback + +max_chat_round: 36 + +# Round-aware reminder injected via ReporterCallback.on_generate_response. +round_reminder: + enabled: true + remind_at_round: 34 + +self_reflection: + enabled: true + max_retries: 3 + min_retention_ratio: 0.6 + post_report_guidance_enabled: true + quality_check: + enabled: true + model: qwen3.5-flash + openai_base_url: + openai_api_key: + +tool_call_timeout: 300 + +output_dir: ./output + +# bench experiment: disable snapshot for reporter sub-agent +enable_snapshots: false diff --git a/projects/deep_research/v2/reporter.py b/projects/deep_research/v2/reporter.py index d9c6f2507..a0227ff65 100644 --- a/projects/deep_research/v2/reporter.py +++ b/projects/deep_research/v2/reporter.py @@ -1,12 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Union from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/v2/reporter.yaml b/projects/deep_research/v2/reporter.yaml index c55fd109b..84183cbf7 100644 --- a/projects/deep_research/v2/reporter.yaml +++ b/projects/deep_research/v2/reporter.yaml @@ -34,8 +34,15 @@ tools: include: - write_file - read_file - - list_files - - replace_file_lines + - edit_file + - grep + - glob + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read evidence_store: mcp: false evidence_dir: evidence @@ -71,10 +78,12 @@ self_reflection: enabled: true max_retries: 3 min_retention_ratio: 0.6 - post_report_guidance_enabled: false + post_report_guidance_enabled: true quality_check: enabled: true model: qwen3.5-flash + openai_base_url: + openai_api_key: tool_call_timeout: 300 diff --git a/projects/deep_research/v2/researcher.exp_nosnap.yaml b/projects/deep_research/v2/researcher.exp_nosnap.yaml new file mode 100644 index 000000000..307539e74 --- /dev/null +++ b/projects/deep_research/v2/researcher.exp_nosnap.yaml @@ -0,0 +1,174 @@ +llm: + service: openai + model: gpt-5.2-2025-12-11 + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + # Enable explicit prefix caching (auto-detects provider from openai_base_url) + force_prefix_cache: false + # Supports role names: system, user, assistant, tool, last_message + prefix_cache_roles: [system, user, assistant, tool] + # extra_body: + # enable_thinking: true + # show_reasoning: true + # reasoning_output: stdout + reasoning_effort: medium + + +tag: deep-research-researcher + + +prompt: + root: prompts/ + agent: researcher + lang: en + family: thinking + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + code_executor: + mcp: false + implementation: python_env + notebook_timeout: 120 + include: + - notebook_executor + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read + evidence_store: + mcp: false + evidence_dir: evidence + include: + - load_index + - get_note + - list_notes + - write_analysis + - get_analysis + - list_analyses + agent_tools: + mcp: false + enable_stats: true + run_in_thread: true + run_in_process: true + max_workers: 4 + definitions: + - tool_name: searcher_tool + description: > + Invoke the Searcher sub-agent to perform an in-depth research task on a specific topic. + Searcher is capable of autonomously executing a research loop until sufficient evidence is collected and a research report is produced (search -> parse -> evidence discovery & storage -> progressive search -> ...). + Returns a JSON result containing: task completion status, core findings, issues or limitations encountered, research report body, and evidence storage locations. + config_path: searcher.exp_nosnap.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted research task description that should include: + - The corresponding task ID from the TODO list (required) + - Specific research objectives + - Questions to be answered + - Constraints (time range, source preferences, etc., optional) + - Stopping conditions (optional) + - Other requirements (optional) + Recommended format: + { + "task_id": "...", + "research_objectives": "...", + "questions_to_answer": "...", + "constraints": "...", + "stopping_conditions": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + - tool_name: reporter_tool + description: > + Invoke the Reporter sub-agent to generate a research report based on collected evidence. + Reporter reads the stored evidence cards and executes a complex workflow for research report writing. + The completed report is automatically saved to `final_report.md` in the output directory by the system. + Returns a JSON result containing: execution summary and + intermediate artifact file paths (the full report body is NOT included in the return value — + read `final_report.md` directly to access the report content). + config_path: reporter.exp_nosnap.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted report generation instruction that should include: + - Report topic and target audience + - Complete background description and task description + - Core questions to be covered + - Writing requirements (style, structure, length, language, etc.) + - Any other requirements + Recommended format: + { + "report_topic_and_audience": "...", + "background": "...", + "task_description": "...", + "writing_requirements": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + plugins: + - tools/evidence_tool.py + + +callbacks: + - callbacks/researcher_callback + +# Self-reflection checks before allowing the researcher to stop. +self_reflection: + enabled: true + max_retries: 3 + compression_check: + enabled: true + min_retention_ratio: 0.5 + report_selection: + enabled: true + min_retention_ratio: 0.5 + report_cleanup: + enabled: false + quality_check: + enabled: true + model: qwen3.5-flash + openai_base_url: + openai_api_key: + +handler: time_handler + +code_file: researcher + +max_chat_round: 45 + +tool_call_timeout: 2600 + +output_dir: ./output + +# bench experiment: disable .ms_agent_snapshots git snapshot +enable_snapshots: false diff --git a/projects/deep_research/v2/researcher.py b/projects/deep_research/v2/researcher.py index 6af015717..2b2726368 100644 --- a/projects/deep_research/v2/researcher.py +++ b/projects/deep_research/v2/researcher.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Union from ms_agent.agent.llm_agent import LLMAgent @@ -8,7 +9,6 @@ from ms_agent.utils.stats import (append_stats, build_timing_record, get_stats_path, monotonic, now_iso, summarize_usage) -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/deep_research/v2/researcher.tavily.yaml b/projects/deep_research/v2/researcher.tavily.yaml new file mode 100644 index 000000000..e2bdd654e --- /dev/null +++ b/projects/deep_research/v2/researcher.tavily.yaml @@ -0,0 +1,171 @@ +llm: + service: openai + model: gpt-5.2-2025-12-11 + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + # Enable explicit prefix caching (auto-detects provider from openai_base_url) + force_prefix_cache: false + # Supports role names: system, user, assistant, tool, last_message + prefix_cache_roles: [system, user, assistant, tool] + # extra_body: + # enable_thinking: true + # show_reasoning: true + # reasoning_output: stdout + reasoning_effort: medium + + +tag: deep-research-researcher + + +prompt: + root: prompts/ + agent: researcher + lang: en + family: thinking + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + code_executor: + mcp: false + implementation: python_env + notebook_timeout: 120 + include: + - notebook_executor + todo_list: + mcp: false + auto_render_md: true + include: + - todo_write + - todo_read + evidence_store: + mcp: false + evidence_dir: evidence + include: + - load_index + - get_note + - list_notes + - write_analysis + - get_analysis + - list_analyses + agent_tools: + mcp: false + enable_stats: true + run_in_thread: true + run_in_process: true + max_workers: 4 + definitions: + - tool_name: searcher_tool + description: > + Invoke the Searcher sub-agent to perform an in-depth research task on a specific topic. + Searcher is capable of autonomously executing a research loop until sufficient evidence is collected and a research report is produced (search -> parse -> evidence discovery & storage -> progressive search -> ...). + Returns a JSON result containing: task completion status, core findings, issues or limitations encountered, research report body, and evidence storage locations. + config_path: searcher.tavily.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted research task description that should include: + - The corresponding task ID from the TODO list (required) + - Specific research objectives + - Questions to be answered + - Constraints (time range, source preferences, etc., optional) + - Stopping conditions (optional) + - Other requirements (optional) + Recommended format: + { + "task_id": "...", + "research_objectives": "...", + "questions_to_answer": "...", + "constraints": "...", + "stopping_conditions": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + - tool_name: reporter_tool + description: > + Invoke the Reporter sub-agent to generate a research report based on collected evidence. + Reporter reads the stored evidence cards and executes a complex workflow for research report writing. + The completed report is automatically saved to `final_report.md` in the output directory by the system. + Returns a JSON result containing: execution summary and + intermediate artifact file paths (the full report body is NOT included in the return value — + read `final_report.md` directly to access the report content). + config_path: reporter.yaml + parameters: + type: object + properties: + request: + type: string + description: > + A JSON-formatted report generation instruction that should include: + - Report topic and target audience + - Complete background description and task description + - Core questions to be covered + - Writing requirements (style, structure, length, language, etc.) + - Any other requirements + Recommended format: + { + "report_topic_and_audience": "...", + "background": "...", + "task_description": "...", + "writing_requirements": "...", + "other_requirements": "...", + } + required: [request] + additionalProperties: false + trust_remote_code: true + output_mode: final_message + max_output_chars: 200000 + plugins: + - tools/evidence_tool.py + + +callbacks: + - callbacks/researcher_callback + +# Self-reflection checks before allowing the researcher to stop. +self_reflection: + enabled: true + max_retries: 3 + compression_check: + enabled: true + min_retention_ratio: 0.5 + report_selection: + enabled: true + min_retention_ratio: 0.5 + report_cleanup: + enabled: false + quality_check: + enabled: true + model: qwen3.5-flash + openai_base_url: + openai_api_key: + +handler: time_handler + +code_file: researcher + +max_chat_round: 45 + +tool_call_timeout: 2600 + +output_dir: ./output diff --git a/projects/deep_research/v2/researcher.yaml b/projects/deep_research/v2/researcher.yaml index 50c2ece5d..f19974482 100644 --- a/projects/deep_research/v2/researcher.yaml +++ b/projects/deep_research/v2/researcher.yaml @@ -1,6 +1,6 @@ llm: service: openai - model: gpt-5-2025-08-07 + model: gpt-5.2-2025-12-11 openai_api_key: openai_base_url: @@ -14,7 +14,10 @@ generation_config: # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] # extra_body: - # enable_thinking: false + # enable_thinking: true + # show_reasoning: true + # reasoning_output: stdout + reasoning_effort: medium tag: deep-research-researcher @@ -24,7 +27,7 @@ prompt: root: prompts/ agent: researcher lang: en - family: gpt5 + family: thinking tools: @@ -33,10 +36,9 @@ tools: include: - write_file - read_file - - list_files - - search_file_content - - replace_file_contents - - replace_file_lines + - edit_file + - grep + - glob code_executor: mcp: false implementation: python_env @@ -63,6 +65,7 @@ tools: mcp: false enable_stats: true run_in_thread: true + run_in_process: true max_workers: 4 definitions: - tool_name: searcher_tool @@ -102,7 +105,7 @@ tools: description: > Invoke the Reporter sub-agent to generate a research report based on collected evidence. Reporter reads the stored evidence cards and executes a complex workflow for research report writing. - The completed report is automatically saved to `final_report.md` in the output directory. + The completed report is automatically saved to `final_report.md` in the output directory by the system. Returns a JSON result containing: execution summary and intermediate artifact file paths (the full report body is NOT included in the return value — read `final_report.md` directly to access the report content). @@ -140,20 +143,29 @@ callbacks: - callbacks/researcher_callback # Self-reflection checks before allowing the researcher to stop. -# Runs inside ResearcherCallback.after_tool_call. self_reflection: enabled: true - max_retries: 2 + max_retries: 3 + compression_check: + enabled: true + min_retention_ratio: 0.5 + report_selection: + enabled: true + min_retention_ratio: 0.5 + report_cleanup: + enabled: false quality_check: enabled: true model: qwen3.5-flash + openai_base_url: + openai_api_key: handler: time_handler code_file: researcher -max_chat_round: 42 +max_chat_round: 45 -tool_call_timeout: 2400 +tool_call_timeout: 2600 output_dir: ./output diff --git a/projects/deep_research/v2/run_benchmark.sh b/projects/deep_research/v2/run_benchmark.sh index 97c451d4f..d4c147f84 100755 --- a/projects/deep_research/v2/run_benchmark.sh +++ b/projects/deep_research/v2/run_benchmark.sh @@ -30,6 +30,11 @@ else echo -e "${RED}Error: Neither 'python' nor 'python3' is available in PATH.${NC}" exit 1 fi +PYTHON_BIN="/Users/luyan/software/miniconda3/bin/python" + +# When stdout is redirected (e.g., nohup > file), Python is block-buffered by default. +# Force unbuffered output so progress lines like "[xx] OK" show up in logs promptly. +export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}" # Use caffeinate on macOS when available; otherwise run normally. RUN_PREFIX=() @@ -87,19 +92,22 @@ if [ -z "$DR_BENCH_ROOT" ]; then echo -e "${YELLOW}Using default benchmark query...${NC}" echo "" - # Run a simple benchmark query - QUERY="Provide a comprehensive survey of recent advances in large language models (LLMs), covering key developments in the last 12 months including architecture innovations, training techniques, and real-world applications." - OUTPUT_DIR="output/deep_research/benchmark_run" + # Run a simple benchmark query (override with BENCH_QUERY for smoke tests) + DEFAULT_QUERY="Provide a comprehensive survey of recent advances in large language models (LLMs), covering key developments in the last 12 months including architecture innovations, training techniques, and real-world applications." + QUERY="${BENCH_QUERY:-$DEFAULT_QUERY}" + OUTPUT_DIR="${BENCH_OUTPUT_DIR:-output/deep_research/benchmark_run}" echo -e "${GREEN}Running benchmark with query:${NC}" echo " \"$QUERY\"" echo "" echo -e "${GREEN}Output directory: $OUTPUT_DIR${NC}" + RESEARCHER_CONFIG="${RESEARCHER_CONFIG:-projects/deep_research/v2/researcher.yaml}" + echo -e "${GREEN}Researcher config: $RESEARCHER_CONFIG${NC}" echo "" - # Run the benchmark - PYTHONPATH=. "$PYTHON_BIN" ms_agent/cli/cli.py run \ - --config projects/deep_research/v2/researcher.yaml \ + # Run the benchmark (override RESEARCHER_CONFIG / BENCH_OUTPUT_DIR as needed) + PYTHONPATH=. "$PYTHON_BIN" -u ms_agent/cli/cli.py run \ + --config "$RESEARCHER_CONFIG" \ --query "$QUERY" \ --trust_remote_code true \ --output_dir "$OUTPUT_DIR" @@ -118,7 +126,7 @@ else # Benchmark subprocess tuning (override via env vars if needed) export DR_BENCH_POST_FINISH_GRACE_S="${DR_BENCH_POST_FINISH_GRACE_S:-180}" - export DR_BENCH_POST_REPORT_EXIT_GRACE_S="${DR_BENCH_POST_REPORT_EXIT_GRACE_S:-3600}" + export DR_BENCH_POST_REPORT_EXIT_GRACE_S="${DR_BENCH_POST_REPORT_EXIT_GRACE_S:-7200}" export DR_BENCH_REPORT_STABLE_WINDOW_S="${DR_BENCH_REPORT_STABLE_WINDOW_S:-10}" export DR_BENCH_SUBPROCESS_POLL_INTERVAL_S="${DR_BENCH_SUBPROCESS_POLL_INTERVAL_S:-0.5}" export DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S="${DR_BENCH_SUBPROCESS_TERMINATE_TIMEOUT_S:-30}" @@ -161,13 +169,16 @@ else echo " Work root: $WORK_ROOT" echo " Workers: $WORKERS" echo " Limit: $LIMIT (0 = no limit)" + RESEARCHER_CONFIG="${RESEARCHER_CONFIG:-projects/deep_research/v2/researcher.yaml}" + echo " Researcher config: $RESEARCHER_CONFIG" echo "" # Run the full benchmark - PYTHONPATH=. "${RUN_PREFIX[@]}" "$PYTHON_BIN" projects/deep_research/v2/eval/dr_bench_runner.py \ + PYTHONPATH=. "${RUN_PREFIX[@]}" "$PYTHON_BIN" -u projects/deep_research/v2/eval/dr_bench_runner.py \ --query_file "$QUERY_FILE" \ --output_jsonl "$OUTPUT_JSONL" \ --model_name "$MODEL_NAME" \ + --config "$RESEARCHER_CONFIG" \ --work_root "$WORK_ROOT" \ --limit "$LIMIT" \ --workers "$WORKERS" \ diff --git a/projects/deep_research/v2/searcher.exp_nosnap.yaml b/projects/deep_research/v2/searcher.exp_nosnap.yaml new file mode 100644 index 000000000..4858fb281 --- /dev/null +++ b/projects/deep_research/v2/searcher.exp_nosnap.yaml @@ -0,0 +1,110 @@ +# Tavily-only web search (no Exa / no Jina Reader). +# Use: cp projects/deep_research/v2/searcher.tavily.yaml projects/deep_research/v2/searcher.yaml +# Or point researcher.yaml searcher_tool config_path to this file. +# Requires TAVILY_API_KEY in .env. + +llm: + service: openai + model: qwen3.5-plus + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + force_prefix_cache: true + prefix_cache_roles: [system, user, assistant, tool] + extra_body: + enable_thinking: false + + +tag: deep-research + + +prompt: + root: prompts/ + agent: searcher + lang: en + family: gpt5 + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + web_search: + mcp: false + engines: + - tavily + tavily_api_key: + max_results: 10 + fetcher: tavily_extract + fetch_content: true + fetch_timeout: 60 + fetch_retries: 3 + per_url_fetch_timeout: 400 + _max_concurrent_fetch: 5 + tavily: + search_depth: advanced + include_raw_content: markdown + include_answer: advanced + chunks_per_source: 3 + max_results: 10 + tavily_request_timeout: 120 + tavily_extract_depth: advanced + tavily_extract_format: markdown + enable_chunking: false + # Off by default: Tavily raw_content is enough; summarizer may hit provider filters. + enable_summarization: false + summarizer_model: qwen3.5-flash + summarizer_base_url: + summarizer_api_key: + max_content_chars: 200000 + summarizer_max_workers: 15 + summarization_timeout: 360 + _max_concurrent_summarization: 15 + # Optional: spill oversized tool JSON to output_dir/web_search_artifacts/… + # spill_large_results: true + # spill_max_inline_chars: 120000 + # spill_preview_chars: 600 + # spill_subdir: web_search_artifacts + evidence_store: + mcp: false + evidence_dir: evidence + chunks_dir: chunks + enable_chunk_storage: false + include: + - write_note + - list_notes + - get_note + - load_index + - search_notes + - delete_note + plugins: + - tools/evidence_tool.py + + +handler: time_handler + +callbacks: + - callbacks/searcher_callback + +max_chat_round: 30 + +round_reminder: + enabled: true + remind_at_round: 28 + +tool_call_timeout: 2600 + +output_dir: ./output + +# bench experiment: disable snapshot for searcher sub-agent +enable_snapshots: false diff --git a/projects/deep_research/v2/searcher.tavily.yaml b/projects/deep_research/v2/searcher.tavily.yaml new file mode 100644 index 000000000..b24887dae --- /dev/null +++ b/projects/deep_research/v2/searcher.tavily.yaml @@ -0,0 +1,107 @@ +# Tavily-only web search (no Exa / no Jina Reader). +# Use: cp projects/deep_research/v2/searcher.tavily.yaml projects/deep_research/v2/searcher.yaml +# Or point researcher.yaml searcher_tool config_path to this file. +# Requires TAVILY_API_KEY in .env. + +llm: + service: openai + model: qwen3.5-plus + openai_api_key: + openai_base_url: + + +generation_config: + stream: true + stream_options: + include_usage: true + force_prefix_cache: true + prefix_cache_roles: [system, user, assistant, tool] + extra_body: + enable_thinking: false + + +tag: deep-research + + +prompt: + root: prompts/ + agent: searcher + lang: en + family: gpt5 + + +tools: + file_system: + mcp: false + include: + - write_file + - read_file + - edit_file + - grep + - glob + web_search: + mcp: false + engines: + - tavily + tavily_api_key: + max_results: 10 + fetcher: tavily_extract + fetch_content: true + fetch_timeout: 60 + fetch_retries: 3 + per_url_fetch_timeout: 400 + _max_concurrent_fetch: 5 + tavily: + search_depth: advanced + include_raw_content: markdown + include_answer: advanced + chunks_per_source: 3 + max_results: 10 + tavily_request_timeout: 120 + tavily_extract_depth: advanced + tavily_extract_format: markdown + enable_chunking: false + # Off by default: Tavily raw_content is enough; summarizer may hit provider filters. + enable_summarization: false + summarizer_model: qwen3.5-flash + summarizer_base_url: + summarizer_api_key: + max_content_chars: 200000 + summarizer_max_workers: 15 + summarization_timeout: 360 + _max_concurrent_summarization: 15 + # Optional: spill oversized tool JSON to output_dir/web_search_artifacts/… + # spill_large_results: true + # spill_max_inline_chars: 120000 + # spill_preview_chars: 600 + # spill_subdir: web_search_artifacts + evidence_store: + mcp: false + evidence_dir: evidence + chunks_dir: chunks + enable_chunk_storage: false + include: + - write_note + - list_notes + - get_note + - load_index + - search_notes + - delete_note + plugins: + - tools/evidence_tool.py + + +handler: time_handler + +callbacks: + - callbacks/searcher_callback + +max_chat_round: 30 + +round_reminder: + enabled: true + remind_at_round: 28 + +tool_call_timeout: 2600 + +output_dir: ./output diff --git a/projects/deep_research/v2/searcher.yaml b/projects/deep_research/v2/searcher.yaml index b9b19f08b..b24887dae 100644 --- a/projects/deep_research/v2/searcher.yaml +++ b/projects/deep_research/v2/searcher.yaml @@ -1,3 +1,8 @@ +# Tavily-only web search (no Exa / no Jina Reader). +# Use: cp projects/deep_research/v2/searcher.tavily.yaml projects/deep_research/v2/searcher.yaml +# Or point researcher.yaml searcher_tool config_path to this file. +# Requires TAVILY_API_KEY in .env. + llm: service: openai model: qwen3.5-plus @@ -9,9 +14,7 @@ generation_config: stream: true stream_options: include_usage: true - # Enable explicit prefix caching (auto-detects provider from openai_base_url) force_prefix_cache: true - # Supports role names: system, user, assistant, tool, last_message prefix_cache_roles: [system, user, assistant, tool] extra_body: enable_thinking: false @@ -33,20 +36,33 @@ tools: include: - write_file - read_file - - list_files + - edit_file + - grep + - glob web_search: mcp: false engines: - - exa - - arxiv - api_key: - max_results: 5 - fetcher: jina_reader + - tavily + tavily_api_key: + max_results: 10 + fetcher: tavily_extract fetch_content: true fetch_timeout: 60 + fetch_retries: 3 + per_url_fetch_timeout: 400 _max_concurrent_fetch: 5 + tavily: + search_depth: advanced + include_raw_content: markdown + include_answer: advanced + chunks_per_source: 3 + max_results: 10 + tavily_request_timeout: 120 + tavily_extract_depth: advanced + tavily_extract_format: markdown enable_chunking: false - enable_summarization: true + # Off by default: Tavily raw_content is enough; summarizer may hit provider filters. + enable_summarization: false summarizer_model: qwen3.5-flash summarizer_base_url: summarizer_api_key: @@ -54,6 +70,11 @@ tools: summarizer_max_workers: 15 summarization_timeout: 360 _max_concurrent_summarization: 15 + # Optional: spill oversized tool JSON to output_dir/web_search_artifacts/… + # spill_large_results: true + # spill_max_inline_chars: 120000 + # spill_preview_chars: 600 + # spill_subdir: web_search_artifacts evidence_store: mcp: false evidence_dir: evidence @@ -77,11 +98,10 @@ callbacks: max_chat_round: 30 -# Round-aware reminder injected via SearcherCallback.on_generate_response. round_reminder: enabled: true remind_at_round: 28 -tool_call_timeout: 300 +tool_call_timeout: 2600 output_dir: ./output diff --git a/projects/deep_research/v2/time_handler.py b/projects/deep_research/v2/time_handler.py index a36282b5f..2240ee515 100644 --- a/projects/deep_research/v2/time_handler.py +++ b/projects/deep_research/v2/time_handler.py @@ -1,9 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from datetime import datetime +from omegaconf import DictConfig from typing import Any from ms_agent.config.config import ConfigLifecycleHandler -from omegaconf import DictConfig class TimeHandler(ConfigLifecycleHandler): diff --git a/projects/deep_research/v2/tools/evidence_tool.py b/projects/deep_research/v2/tools/evidence_tool.py index 064ab6a87..dd6866f4b 100644 --- a/projects/deep_research/v2/tools/evidence_tool.py +++ b/projects/deep_research/v2/tools/evidence_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import time import uuid from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock @@ -372,14 +372,15 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'Each note represents ONE piece of evidence: a claim/observation with supporting text. ' 'Returns the generated note_id.'), parameters={ - 'type': - 'object', + 'type': 'object', 'properties': { 'title': { 'type': 'string', 'description': - 'Brief title describing this evidence (e.g., "Tesla Q3 revenue growth").', + ('Brief title describing this evidence (e.g., "Tesla Q3 revenue growth"). ' + 'Optional: if omitted, a title is derived from the first line of `content`.' + ), }, 'content': { 'type': @@ -464,12 +465,8 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'Optional: Confidence/quality score (0-100).', }, }, - 'required': [ - 'title', 'content', 'sources', 'summary', - 'task_id', 'tags' - ], - 'additionalProperties': - False, + 'required': ['content'], + 'additionalProperties': False, }, ), Tool( @@ -849,8 +846,8 @@ def _load_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: async def write_note( self, - title: str, content: str, + title: Optional[str] = None, contradicts: Optional[str] = None, sources: Optional[List[Dict[str, Any]]] = None, summary: Optional[str] = None, @@ -863,12 +860,29 @@ async def write_note( _ensure_dir(paths['notes_dir']) _ensure_dir(paths['lock_dir']) + content = (content or '').strip() + if not content: + return _json_dumps({ + 'status': + 'error', + 'message': + 'write_note requires non-empty content.', + }) + + if title is None or not str(title).strip(): + first_line = content.split('\n', 1)[0].strip() + if len(first_line) > 120: + first_line = first_line[:117] + '...' + title_resolved = first_line or 'Evidence note' + else: + title_resolved = str(title).strip() + # Generate ID and build note note_id = _generate_note_id() note: Dict[str, Any] = { 'note_id': note_id, - 'title': title.strip(), - 'content': content.strip(), + 'title': title_resolved, + 'content': content, 'created_at': _now_iso(), } diff --git a/projects/deep_research/v2/tools/report_tool.py b/projects/deep_research/v2/tools/report_tool.py index 96fd51994..924b49804 100644 --- a/projects/deep_research/v2/tools/report_tool.py +++ b/projects/deep_research/v2/tools/report_tool.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import os import re import time import uuid from typing import Any, Dict, List, Optional -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils.utils import file_lock, render_markdown_todo @@ -43,6 +43,41 @@ def _write_text(path: str, content: str) -> None: f.write(content) +def _coerce_chapters_argument( + chapters: Any) -> tuple[List[Dict[str, Any]], Optional[str]]: + """Normalize `chapters` from the model (list, JSON string, or nested strings).""" + if chapters is None: + return [], ( + 'commit_outline requires `chapters` (array of chapter objects, ' + 'or a JSON string of that array).') + raw: Any = chapters + if isinstance(raw, str): + try: + raw = json.loads(raw.strip()) + except json.JSONDecodeError as e: + return [], ( + 'commit_outline `chapters` must be a JSON array of objects, ' + f'or a JSON string of that array: {e}') + if not isinstance(raw, list): + return [], ( + f'commit_outline `chapters` must be a list, got {type(chapters).__name__}.' + ) + out: List[Dict[str, Any]] = [] + for i, ch in enumerate(raw): + if isinstance(ch, str): + try: + ch = json.loads(ch.strip()) + except json.JSONDecodeError: + return [], (f'commit_outline chapters[{i}] must be an object; ' + 'string entry is not valid JSON for an object.') + if not isinstance(ch, dict): + return [], ( + f'commit_outline chapters[{i}] must be an object, got {type(ch).__name__}.' + ) + out.append(ch) + return out, None + + def _render_outline_md(outline: Dict[str, Any]) -> str: """Render outline as Markdown.""" lines = [f"# {outline.get('title', 'Report Outline')}", ''] @@ -178,6 +213,37 @@ def _paths(self) -> Dict[str, str]: os.path.join(self.output_dir, self._lock_subdir), } + def _filter_candidate_evidence( + self, + paths: Dict[str, str], + candidate: Any, + ) -> tuple[List[str], List[str]]: + """Keep only note ids that have ``note_{id}.md`` under evidence notes dir. + + Returns: + (kept_ids_in_order, dropped_ids) — duplicates in input are skipped + after the first occurrence. + """ + if not isinstance(candidate, list): + return [], [] + notes_dir = paths['evidence_notes_dir'] + kept: List[str] = [] + dropped: List[str] = [] + seen: set[str] = set() + for raw in candidate: + nid = str(raw).strip() if raw is not None else '' + if not nid: + continue + if nid in seen: + continue + seen.add(nid) + note_path = os.path.join(notes_dir, f'note_{nid}.md') + if os.path.exists(note_path): + kept.append(nid) + else: + dropped.append(nid) + return kept, dropped + async def _get_tools_inner(self) -> Dict[str, Any]: tools: Dict[str, List[Tool]] = { self.SERVER_NAME: [ @@ -500,6 +566,21 @@ async def _get_tools_inner(self) -> Dict[str, Any]: 'additionalProperties': False, }, ), + Tool( + tool_name='load_index', + server_name=self.SERVER_NAME, + description= + ('Load the full evidence index (notes + analyses metadata). ' + 'Same data as evidence_store---load_index; provided here so calls ' + 'mistakenly prefixed with report_generator--- still work.' + ), + parameters={ + 'type': 'object', + 'properties': {}, + 'required': [], + 'additionalProperties': False, + }, + ), ] } return tools @@ -536,6 +617,27 @@ def _load_evidence_index(self, paths: Dict[str, str]) -> Dict[str, Any]: return {'notes': {}} return data + def _load_full_evidence_index(self, paths: Dict[str, + str]) -> Dict[str, Any]: + """Load evidence/index.json with the same defaults as EvidenceTool.""" + data = _safe_read_json(paths['evidence_index']) + if data is None or not isinstance(data, dict): + return { + 'schema_version': 2, + 'updated_at': _now_iso(), + 'notes': {}, + 'analyses': {}, + } + if 'notes' not in data or not isinstance(data.get('notes'), dict): + data['notes'] = {} + if 'analyses' not in data or not isinstance( + data.get('analyses'), dict): + data['analyses'] = {} + legacy = data.get('conclusions') + if isinstance(legacy, dict) and legacy and not data.get('analyses'): + data['analyses'] = legacy + return data + def _load_note_content(self, paths: Dict[str, str], note_id: str) -> Optional[Dict[str, Any]]: """Load a single note's full content from markdown file.""" @@ -565,27 +667,53 @@ def _save_conflict(self, paths: Dict[str, str], conflict['updated_at'] = _now_iso() _write_text(paths['conflict_json'], _json_dumps(conflict)) + async def load_index(self) -> str: + """Return evidence index JSON (alias for mistaken report_generator---load_index).""" + paths = self._paths() + _ensure_dir(paths['lock_dir']) + with file_lock(paths['lock_dir'], 'evidence_index'): + index = self._load_full_evidence_index(paths) + notes = index.get('notes', {}) + analyses = index.get('analyses', {}) + return _json_dumps({ + 'status': 'ok', + 'updated_at': index.get('updated_at', ''), + 'total_notes': len(notes), + 'total_analyses': len(analyses), + 'notes': notes, + 'analyses': analyses, + }) + async def commit_outline( self, title: str, - chapters: List[Dict[str, Any]], + chapters: Any, ) -> str: """Generate report outline with chapter structure.""" paths = self._paths() _ensure_dir(paths['chapters_dir']) _ensure_dir(paths['lock_dir']) + chapters_list, chapters_err = _coerce_chapters_argument(chapters) + if chapters_err: + return _json_dumps({'status': 'error', 'message': chapters_err}) + # Load evidence index to validate coverage evidence_index = self._load_evidence_index(paths) all_note_ids = set(evidence_index.get('notes', {}).keys()) - # Build outline + # Build outline (only bind note ids that exist on disk) outline_chapters = [] covered_evidence = set() + invalid_candidate_by_chapter: Dict[str, List[str]] = {} - for idx, ch in enumerate(chapters, start=1): - candidate = ch.get('candidate_evidence', []) - covered_evidence.update(candidate) + for idx, ch in enumerate(chapters_list, start=1): + candidate_raw = ch.get('candidate_evidence', []) + kept, dropped = self._filter_candidate_evidence( + paths, candidate_raw) + if dropped: + invalid_candidate_by_chapter[str(idx)] = dropped + covered_evidence.update(kept) outline_chapters.append({ 'chapter_id': @@ -597,7 +725,7 @@ async def commit_outline( 'sections_description': ch.get('sections_description', ''), 'candidate_evidence': - candidate, + kept, 'status': 'pending', }) @@ -632,6 +760,14 @@ async def commit_outline( if coverage_warning: result['warning'] = coverage_warning + if invalid_candidate_by_chapter: + result['invalid_candidate_evidence'] = invalid_candidate_by_chapter + result['invalid_candidate_evidence_note'] = ( + 'These note ids were removed from candidate_evidence because ' + 'no matching evidence/notes/note_.md file exists. ' + 'Use list_notes or prior write_note responses to pick valid ids.' + ) + return _json_dumps(result) async def prepare_chapter_bundle( @@ -667,13 +803,41 @@ async def prepare_chapter_bundle( 'message': f'Chapter {chapter_id} not found.' }) + cand_kept, cand_dropped = self._filter_candidate_evidence( + paths, chapter.get('candidate_evidence', [])) + rel_kept, rel_dropped = self._filter_candidate_evidence( + paths, relevant_evidence or []) + # Load evidence content evidence_index = self._load_evidence_index(paths) notes_meta = evidence_index.get('notes', {}) + _known_sorted = sorted(notes_meta.keys()) + _sample = _known_sorted[:48] + _note_id_hint = ('Known note ids in evidence index (sample): ' + + (', '.join(_sample) if _sample else '(none)')) + if len(_known_sorted) > len(_sample): + _note_id_hint += f' … (+{len(_known_sorted) - len(_sample)} more)' + + def _missing_note_entry(note_id: str, + meta: Dict[str, Any]) -> Dict[str, Any]: + return { + 'note_id': + note_id, + 'error': + f'Note {note_id} not found', + 'title': + meta.get('title', ''), + 'summary': + meta.get('summary', ''), + 'hint': + (f'{_note_id_hint}. ' + 'Align outline candidate_evidence with filenames under evidence_store, ' + 'or list existing notes before referencing ids.'), + } notes_content = [] seen_note_ids = set() - for note_id in chapter.get('candidate_evidence', []): + for note_id in cand_kept: meta = notes_meta.get(note_id, {}) note_data = self._load_note_content(paths, note_id) @@ -697,17 +861,11 @@ async def prepare_chapter_bundle( meta.get('tags', note_data.get('tags', [])), }) else: - # Note not found, include minimal info - notes_content.append({ - 'note_id': note_id, - 'error': f'Note {note_id} not found', - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - }) + notes_content.append(_missing_note_entry(note_id, meta)) seen_note_ids.add(note_id) - for note_id in relevant_evidence: + for note_id in rel_kept: if note_id not in seen_note_ids: meta = notes_meta.get(note_id, {}) note_data = self._load_note_content(paths, note_id) @@ -733,18 +891,11 @@ async def prepare_chapter_bundle( meta.get('tags', note_data.get('tags', [])), }) else: - notes_content.append({ - 'note_id': note_id, - 'error': f'Note {note_id} not found', - 'title': meta.get('title', ''), - 'summary': meta.get('summary', ''), - }) + notes_content.append(_missing_note_entry(note_id, meta)) - # Build meta + # Build meta (only ids that resolved to on-disk notes for this bundle) candidate_evidence = list( - dict.fromkeys( - list(chapter.get('candidate_evidence', [])) - + list(relevant_evidence or []))) + dict.fromkeys(list(cand_kept) + list(rel_kept))) meta = { 'chapter_id': chapter_id, 'chapter_title': chapter['title'], @@ -767,22 +918,27 @@ async def prepare_chapter_bundle( with file_lock(paths['lock_dir'], 'report_outline'): self._save_outline(paths, outline) - return _json_dumps({ - 'status': - 'ok', - 'chapter_id': - chapter_id, - 'chapter_title': - chapter['title'], - 'chapter_goals': - chapter.get('goals', []), - 'evidence_count': - len(notes_content), - 'meta_path': - os.path.relpath(meta_path, self.output_dir), - 'notes_content': - notes_content, - }) + out_bundle: Dict[str, Any] = { + 'status': 'ok', + 'chapter_id': chapter_id, + 'chapter_title': chapter['title'], + 'chapter_goals': chapter.get('goals', []), + 'evidence_count': len(notes_content), + 'meta_path': os.path.relpath(meta_path, self.output_dir), + 'notes_content': notes_content, + } + skipped: Dict[str, List[str]] = {} + if cand_dropped: + skipped['outline_candidate_evidence'] = cand_dropped + if rel_dropped: + skipped['relevant_evidence_argument'] = rel_dropped + if skipped: + out_bundle['skipped_invalid_note_ids'] = skipped + out_bundle['skipped_invalid_note_ids_note'] = ( + 'These ids were ignored (no evidence/notes/note_.md). ' + 'Update outline candidate_evidence via update_outline if needed.' + ) + return _json_dumps(out_bundle) async def commit_chapter( self, @@ -922,6 +1078,7 @@ async def update_outline( }) chapter_found = False + invalid_candidate_removed: List[str] = [] for ch in outline.get('chapters', []): if ch['chapter_id'] == chapter_id: if 'title' in updates: @@ -932,8 +1089,10 @@ async def update_outline( ch['sections_description'] = updates[ 'sections_description'] if 'candidate_evidence' in updates: - ch['candidate_evidence'] = updates[ - 'candidate_evidence'] + kept, dropped = self._filter_candidate_evidence( + paths, updates['candidate_evidence']) + ch['candidate_evidence'] = kept + invalid_candidate_removed = dropped chapter_found = True break @@ -947,11 +1106,17 @@ async def update_outline( self._save_outline(paths, outline) - return _json_dumps({ + out: Dict[str, Any] = { 'status': 'ok', 'chapter_id': chapter_id, 'updates_applied': list(updates.keys()), - }) + } + if invalid_candidate_removed: + out['invalid_candidate_evidence_removed'] = invalid_candidate_removed + out['invalid_candidate_evidence_note'] = ( + 'These ids were removed from candidate_evidence because no ' + 'matching evidence/notes/note_.md exists.') + return _json_dumps(out) async def assemble_draft( self, diff --git a/projects/fin_research/aggregator.py b/projects/fin_research/aggregator.py index 2380e58e3..b9e8fca5a 100644 --- a/projects/fin_research/aggregator.py +++ b/projects/fin_research/aggregator.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os +from callbacks.file_parser import extract_code_blocks +from omegaconf import DictConfig from typing import Any, AsyncGenerator, List, Union -import json -from callbacks.file_parser import extract_code_blocks from ms_agent.agent.llm_agent import LLMAgent from ms_agent.llm.utils import Message from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_TAG -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/aggregator.yaml b/projects/fin_research/aggregator.yaml index 2efcabfe8..15cc8452e 100644 --- a/projects/fin_research/aggregator.yaml +++ b/projects/fin_research/aggregator.yaml @@ -84,8 +84,8 @@ prompt: - You may interpret the meaning of embedded images from their surrounding context \ in the financial data analysis report. - You should note that most of the charts generated during the analysis process \ - are stored in the default working directory. You may use the tools under the file_system \ - server to browse these files and select meaningful images — based on their filenames — to embed into the report. + are stored in the default working directory. Use `file_system---glob` (e.g. pattern `**/*.{png,jpg,svg}`) \ + or `file_system---read_file` to inspect filenames and embed meaningful images into the report. - Avoid inserting any images whose titles or filenames do not convey meaningful information. - Avoid inserting any images that are from websites or other external sources. - The image paths should be relative paths to the the default working directory. @@ -127,13 +127,22 @@ prompt: tools: + code_executor: + mcp: false + implementation: python_env + exclude: + - notebook_executor + - python_executor + - reset_executor + - get_executor_info file_system: mcp: false include: - read_file - write_file - - list_files - - delete_file_or_dir + - edit_file + - grep + - glob spec_loader: mcp: false plugins: diff --git a/projects/fin_research/callbacks/aggregator_callback.py b/projects/fin_research/callbacks/aggregator_callback.py index b827f18a6..7faf6513f 100644 --- a/projects/fin_research/callbacks/aggregator_callback.py +++ b/projects/fin_research/callbacks/aggregator_callback.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os import re +from omegaconf import DictConfig from typing import List from ms_agent.agent.runtime import Runtime @@ -8,7 +9,6 @@ from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/analyst_callback.py b/projects/fin_research/callbacks/analyst_callback.py index 31af7306b..8b0b0f4e6 100644 --- a/projects/fin_research/callbacks/analyst_callback.py +++ b/projects/fin_research/callbacks/analyst_callback.py @@ -1,15 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re +from omegaconf import DictConfig from pathlib import Path from typing import List -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/collector_callback.py b/projects/fin_research/callbacks/collector_callback.py index caffbcb71..c8c0eadb4 100644 --- a/projects/fin_research/callbacks/collector_callback.py +++ b/projects/fin_research/callbacks/collector_callback.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os +from omegaconf import DictConfig from pathlib import Path from typing import List -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/callbacks/orchestrator_callback.py b/projects/fin_research/callbacks/orchestrator_callback.py index a649504ab..7b684841a 100644 --- a/projects/fin_research/callbacks/orchestrator_callback.py +++ b/projects/fin_research/callbacks/orchestrator_callback.py @@ -1,13 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os +from file_parser import extract_code_blocks +from omegaconf import DictConfig from typing import List -from file_parser import extract_code_blocks from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback from ms_agent.llm.utils import Message from ms_agent.tools.filesystem_tool import FileSystemTool from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -31,7 +32,7 @@ async def on_task_end(self, runtime: Runtime, messages: List[Message]): if messages[-1].tool_calls or messages[-1].role == 'tool': return - await self.file_system.create_directory() + os.makedirs(self.file_system.output_dir, exist_ok=True) content = '\n'.join([m.content for m in messages[2:]]) all_files, _ = extract_code_blocks(content) results = [] diff --git a/projects/fin_research/collector.yaml b/projects/fin_research/collector.yaml index fde543744..7fd1967a1 100644 --- a/projects/fin_research/collector.yaml +++ b/projects/fin_research/collector.yaml @@ -132,10 +132,12 @@ tools: network_enabled: True # Enable network access in sandbox (default: false for security) tools_config: notebook_executor: {} + shell_executor: {} + file_operation: {} exclude: - python_executor - - shell_executor - - file_operation + - reset_executor + - get_executor_info financial_data_fetcher: mcp: false source_type: hybrid @@ -149,7 +151,10 @@ tools: mcp: false include: - write_file - - delete_file_or_dir + - read_file + - edit_file + - grep + - glob handler: time_handler diff --git a/projects/fin_research/searcher.py b/projects/fin_research/searcher.py index 23703fa64..543227ad1 100644 --- a/projects/fin_research/searcher.py +++ b/projects/fin_research/searcher.py @@ -1,8 +1,9 @@ +import json import os +from callbacks.file_parser import extract_code_blocks +from omegaconf import DictConfig from typing import List, Union -import json -from callbacks.file_parser import extract_code_blocks from ms_agent.agent.code_agent import CodeAgent from ms_agent.llm import Message from ms_agent.llm.openai import OpenAIChat @@ -10,7 +11,6 @@ from ms_agent.utils import get_logger from ms_agent.workflow.deep_research.research_workflow_beta import \ ResearchWorkflowBeta -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/fin_research/time_handler.py b/projects/fin_research/time_handler.py index dd01bc8ae..ef9d8af8b 100644 --- a/projects/fin_research/time_handler.py +++ b/projects/fin_research/time_handler.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from datetime import datetime +from omegaconf import DictConfig from typing import Any from ms_agent.config.config import ConfigLifecycleHandler -from omegaconf import DictConfig class TimeHandler(ConfigLifecycleHandler): diff --git a/projects/fin_research/tools/principle_skill.py b/projects/fin_research/tools/principle_skill.py index 19882f497..8c228de5f 100644 --- a/projects/fin_research/tools/principle_skill.py +++ b/projects/fin_research/tools/principle_skill.py @@ -1,9 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # flake8: noqa +import json import os from typing import Any, Dict, List, Optional, Tuple -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger diff --git a/projects/fin_research/tools/spec_loader.py b/projects/fin_research/tools/spec_loader.py index 6ed8e1e21..d2e5b803e 100644 --- a/projects/fin_research/tools/spec_loader.py +++ b/projects/fin_research/tools/spec_loader.py @@ -1,14 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # flake8: noqa +import json import os +from spec_constant import (PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, + WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE) from typing import Any, Dict, List, Tuple -import json from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import get_logger -from spec_constant import (PRINCIPLE_ROUTING_GUIDE, PRINCIPLE_SPEC_GUIDE, - WRITING_ROUTING_GUIDE, WRITING_SPEC_GUIDE) logger = get_logger() diff --git a/projects/singularity_cinema/agent.yaml b/projects/singularity_cinema/agent.yaml index dc1756486..14fe6809b 100644 --- a/projects/singularity_cinema/agent.yaml +++ b/projects/singularity_cinema/agent.yaml @@ -280,10 +280,6 @@ tools: allow_read_all_files: true exclude: - edit_file - - list_files - - search_file_content - - search_file_name - - replace_file_lines memory: diversity: diff --git a/projects/singularity_cinema/compose_video/agent.py b/projects/singularity_cinema/compose_video/agent.py index 09995c629..1fd2b0e95 100644 --- a/projects/singularity_cinema/compose_video/agent.py +++ b/projects/singularity_cinema/compose_video/agent.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import math +import moviepy as mp import os import shutil from copy import deepcopy - -import json -import moviepy as mp from moviepy import AudioClip +from omegaconf import DictConfig +from PIL import Image + from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/create_background/agent.py b/projects/singularity_cinema/create_background/agent.py index 44e510b13..fc1be2610 100644 --- a/projects/singularity_cinema/create_background/agent.py +++ b/projects/singularity_cinema/create_background/agent.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import matplotlib.font_manager as fm import os import textwrap +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont -import matplotlib.font_manager as fm from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont logger = get_logger() diff --git a/projects/singularity_cinema/generate_animation/agent.py b/projects/singularity_cinema/generate_animation/agent.py index 26837050b..780c5b925 100644 --- a/projects/singularity_cinema/generate_animation/agent.py +++ b/projects/singularity_cinema/generate_animation/agent.py @@ -2,9 +2,9 @@ import importlib.util import os import sys +from omegaconf import DictConfig from ms_agent.agent import CodeAgent -from omegaconf import DictConfig class GenerateAnimation(CodeAgent): diff --git a/projects/singularity_cinema/generate_animation/generate_manim_code.py b/projects/singularity_cinema/generate_animation/generate_manim_code.py index 990bab72d..e332959e3 100644 --- a/projects/singularity_cinema/generate_animation/generate_manim_code.py +++ b/projects/singularity_cinema/generate_animation/generate_manim_code.py @@ -1,14 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from concurrent.futures import ThreadPoolExecutor, as_completed +from omegaconf import DictConfig +from PIL import Image from typing import List, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/generate_animation/generate_remotion_code.py b/projects/singularity_cinema/generate_animation/generate_remotion_code.py index 1007a60b1..4d54585b7 100644 --- a/projects/singularity_cinema/generate_animation/generate_remotion_code.py +++ b/projects/singularity_cinema/generate_animation/generate_remotion_code.py @@ -1,16 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import glob +import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed +from omegaconf import DictConfig +from PIL import Image from typing import List, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/generate_audio/agent.py b/projects/singularity_cinema/generate_audio/agent.py index 499779c73..e296bd4e9 100644 --- a/projects/singularity_cinema/generate_audio/agent.py +++ b/projects/singularity_cinema/generate_audio/agent.py @@ -1,21 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import asyncio +import edge_tts +import json +import numpy as np import os import shutil from copy import deepcopy from dataclasses import dataclass, field +from moviepy import AudioClip, AudioFileClip +from omegaconf import DictConfig from typing import List -import edge_tts -import json -import numpy as np -from moviepy import AudioClip, AudioFileClip from ms_agent.agent import CodeAgent from ms_agent.llm import LLM from ms_agent.llm.openai_llm import OpenAI from ms_agent.tools.audio_generator import AudioGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_illustration_prompts/agent.py b/projects/singularity_cinema/generate_illustration_prompts/agent.py index 5529dc67a..b7c774aba 100644 --- a/projects/singularity_cinema/generate_illustration_prompts/agent.py +++ b/projects/singularity_cinema/generate_illustration_prompts/agent.py @@ -1,16 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field +from omegaconf import DictConfig from typing import List, Optional, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_images/agent.py b/projects/singularity_cinema/generate_images/agent.py index e21eb325d..3e5efe03d 100644 --- a/projects/singularity_cinema/generate_images/agent.py +++ b/projects/singularity_cinema/generate_images/agent.py @@ -1,22 +1,22 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import aiohttp import asyncio +import json +import numpy as np import os import re import shutil from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from io import BytesIO +from omegaconf import DictConfig +from PIL import Image from typing import List, Union -import aiohttp -import json -import numpy as np from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.image_generator import ImageGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/generate_script/agent.py b/projects/singularity_cinema/generate_script/agent.py index fc0a726cc..0586aafc2 100644 --- a/projects/singularity_cinema/generate_script/agent.py +++ b/projects/singularity_cinema/generate_script/agent.py @@ -1,12 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from copy import deepcopy +from omegaconf import DictConfig from typing import List from ms_agent import LLMAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_subtitle/agent.py b/projects/singularity_cinema/generate_subtitle/agent.py index efa873851..d956aa2c2 100644 --- a/projects/singularity_cinema/generate_subtitle/agent.py +++ b/projects/singularity_cinema/generate_subtitle/agent.py @@ -1,16 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json +import matplotlib.font_manager as fm import os import re +from omegaconf import DictConfig +from PIL import Image, ImageDraw, ImageFont from typing import List -import json -import matplotlib.font_manager as fm from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image, ImageDraw, ImageFont logger = get_logger() diff --git a/projects/singularity_cinema/generate_video/agent.py b/projects/singularity_cinema/generate_video/agent.py index aaecd7d5f..2393bb98e 100644 --- a/projects/singularity_cinema/generate_video/agent.py +++ b/projects/singularity_cinema/generate_video/agent.py @@ -1,18 +1,18 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import aiohttp import asyncio +import json import os import shutil from concurrent.futures import ThreadPoolExecutor from copy import deepcopy +from omegaconf import DictConfig from typing import List, Union -import aiohttp -import json from ms_agent.agent import CodeAgent from ms_agent.llm import Message from ms_agent.tools.video_generator import VideoGenerator from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/generate_video_prompts/agent.py b/projects/singularity_cinema/generate_video_prompts/agent.py index 2cd091f3b..3f70599ee 100644 --- a/projects/singularity_cinema/generate_video_prompts/agent.py +++ b/projects/singularity_cinema/generate_video_prompts/agent.py @@ -1,13 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from concurrent.futures import ThreadPoolExecutor, as_completed +from omegaconf import DictConfig from typing import List, Union -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/projects/singularity_cinema/parse_images/agent.py b/projects/singularity_cinema/parse_images/agent.py index 138803dc0..c42ce46dc 100644 --- a/projects/singularity_cinema/parse_images/agent.py +++ b/projects/singularity_cinema/parse_images/agent.py @@ -1,19 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 import hashlib +import json import os import re from concurrent.futures import ThreadPoolExecutor from copy import deepcopy +from omegaconf import DictConfig +from PIL import Image from urllib.request import urlretrieve -import json from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.llm.openai_llm import OpenAI from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/render_animation/agent.py b/projects/singularity_cinema/render_animation/agent.py index 1b3007c66..e2d2bee86 100644 --- a/projects/singularity_cinema/render_animation/agent.py +++ b/projects/singularity_cinema/render_animation/agent.py @@ -2,9 +2,9 @@ import os import sys +from omegaconf import DictConfig from ms_agent.agent import CodeAgent -from omegaconf import DictConfig class RenderAnimation(CodeAgent): diff --git a/projects/singularity_cinema/render_animation/render_manim.py b/projects/singularity_cinema/render_animation/render_manim.py index 19b9abc27..17f62aa4a 100644 --- a/projects/singularity_cinema/render_animation/render_manim.py +++ b/projects/singularity_cinema/render_animation/render_manim.py @@ -1,21 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import base64 +import json import os import re import shutil import subprocess from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy +from moviepy import VideoFileClip +from omegaconf import DictConfig from os import getcwd +from PIL import Image from typing import List, Union -import json -from moviepy import VideoFileClip from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig -from PIL import Image logger = get_logger() diff --git a/projects/singularity_cinema/render_animation/render_remotion.py b/projects/singularity_cinema/render_animation/render_remotion.py index e58e2927e..09426b91f 100644 --- a/projects/singularity_cinema/render_animation/render_remotion.py +++ b/projects/singularity_cinema/render_animation/render_remotion.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os import re import shutil @@ -6,14 +7,13 @@ import urllib.request import zipfile from collections import defaultdict +from moviepy import VideoFileClip +from omegaconf import DictConfig from typing import List, Optional, Tuple, Union -import json -from moviepy import VideoFileClip from ms_agent.agent import CodeAgent from ms_agent.llm import LLM, Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() @@ -718,8 +718,8 @@ def _check_edge_clipping(frame_path, threshold=10): Returns True if clipping detected (colored pixels at edges). """ try: - from PIL import Image import numpy as np + from PIL import Image img = Image.open(frame_path).convert('RGB') pixels = np.array(img) diff --git a/projects/singularity_cinema/segment/agent.py b/projects/singularity_cinema/segment/agent.py index 35df82260..f35718450 100644 --- a/projects/singularity_cinema/segment/agent.py +++ b/projects/singularity_cinema/segment/agent.py @@ -1,12 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import json import os from copy import deepcopy +from omegaconf import DictConfig -import json from ms_agent.agent import LLMAgent from ms_agent.llm import Message from ms_agent.utils import get_logger -from omegaconf import DictConfig logger = get_logger() diff --git a/requirements/research.txt b/requirements/research.txt index 67ce2d3fd..693301a94 100644 --- a/requirements/research.txt +++ b/requirements/research.txt @@ -14,3 +14,4 @@ Pillow python-dotenv requests rich +socksio diff --git a/setup.py b/setup.py index 1bded7e04..aec42aa2f 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # !/usr/bin/env python -import os -import shutil from setuptools import find_packages, setup from setuptools.command.build_py import build_py as _build_py + +import os +import shutil from typing import List diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py index 5a4f43213..705c7184e 100644 --- a/tests/knowledge_search/test_sirschmunk.py +++ b/tests/knowledge_search/test_sirschmunk.py @@ -1,23 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. +"""Tests for SirchmunkSearch and localsearch tool integration. -These tests verify the sirchmunk-based knowledge search functionality -through the LLMAgent entry point, including verification that -search_result and searching_detail fields are properly populated. - -To run these tests, you need to set the following environment variables: - - TEST_LLM_API_KEY: Your LLM API key - - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) - - TEST_LLM_MODEL_NAME: Your LLM model name (optional) - - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) - - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) - -Example: +Example (full sirchmunk run): export TEST_LLM_API_KEY="your-api-key" - export TEST_LLM_BASE_URL="https://api.openai.com/v1" - export TEST_LLM_MODEL_NAME="gpt-4o" - export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" - export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" python -m pytest tests/knowledge_search/test_sirschmunk.py """ import asyncio @@ -26,92 +11,41 @@ import unittest from pathlib import Path -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.agent import LLMAgent -from ms_agent.config import Config +from ms_agent.tools.search.sirchmunk_search import SirchmunkSearch +from ms_agent.llm.utils import Message +from ms_agent.tools.tool_manager import ToolManager from omegaconf import DictConfig -from modelscope.utils.test_utils import test_level - - -class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): - """Test cases for SirchmunkSearch integration with LLMAgent. - - These tests verify that when LLMAgent runs a query that triggers - knowledge search, the Message objects have search_result and - searching_detail fields properly populated. - """ +class SirchmunkKnowledgeSearchTest(unittest.TestCase): + """Sirchmunk config, ToolManager registration""" @classmethod def setUpClass(cls): - """Set up test fixtures.""" - # Create test directory with sample files cls.test_dir = Path('./test_llm_agent_knowledge') cls.test_dir.mkdir(exist_ok=True) - - # Create sample documentation - (cls.test_dir / 'README.md').write_text(''' -# Test Project Documentation - -## Overview -This is a test project for knowledge search integration. - -## API Reference - -### UserManager -The UserManager class handles user operations: -- create_user: Create a new user account -- delete_user: Delete an existing user -- update_user: Update user information -- get_user: Retrieve user details - -### AuthService -The AuthService class handles authentication: -- login: Authenticate user credentials -- logout: End user session -- refresh_token: Refresh authentication token -- verify_token: Validate authentication token -''') - - (cls.test_dir / 'config.py').write_text(''' -"""Configuration module.""" - -class Config: - """Application configuration.""" - - def __init__(self): - self.database_url = "postgresql://localhost:5432/mydb" - self.secret_key = "your-secret-key" - self.debug_mode = False - - def load_from_env(self): - """Load configuration from environment variables.""" - import os - self.database_url = os.getenv("DATABASE_URL", self.database_url) - self.secret_key = os.getenv("SECRET_KEY", self.secret_key) - return self -''') + (cls.test_dir / 'README.md').write_text( + '# Demo\n\nUserManager.create_user creates a user.\n') @classmethod def tearDownClass(cls): - """Clean up test fixtures.""" if cls.test_dir.exists(): shutil.rmtree(cls.test_dir, ignore_errors=True) work_dir = Path('./.sirchmunk') if work_dir.exists(): shutil.rmtree(work_dir, ignore_errors=True) - def _get_agent_config(self): - """Create agent configuration with knowledge search.""" + def _base_config(self) -> DictConfig: llm_api_key = os.getenv('TEST_LLM_API_KEY', 'test-api-key') - llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', + 'https://api.openai.com/v1') llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') - # Read from TEST_* env vars (for test-specific config) - # These can be set from .env file which uses TEST_* prefix embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') - embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') - - config = DictConfig({ + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', + '') + return DictConfig({ + 'output_dir': + './outputs_knowledge_test', 'llm': { 'service': 'openai', 'model': llm_model_name, @@ -122,81 +56,66 @@ def _get_agent_config(self): 'temperature': 0.3, 'max_tokens': 500, }, - 'knowledge_search': { - 'name': 'SirchmunkSearch', - 'paths': [str(self.test_dir)], - 'work_path': './.sirchmunk', - 'llm_api_key': llm_api_key, - 'llm_base_url': llm_base_url, - 'llm_model_name': llm_model_name, - 'embedding_model': embedding_model_id, - 'embedding_model_cache_dir': embedding_model_cache_dir, - 'mode': 'FAST', - } + 'tools': { + 'localsearch': { + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + }, + }, }) - return config - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_llm_agent_with_knowledge_search(self): - """Test LLMAgent using knowledge search. - - This test verifies that: - 1. LLMAgent can be initialized with SirchmunkSearch configuration - 2. Running a query produces a valid response - 3. User message has searching_detail and search_result populated - 4. searching_detail contains expected keys (logs, mode, paths) - 5. search_result is a list - """ - config = self._get_agent_config() + def test_does_not_inject_knowledge_search(self): + """Local sirchmunk search is no longer merged into the user message here.""" + config = self._base_config() agent = LLMAgent(config=config, tag='test-knowledge-agent') - - # Test query that should trigger knowledge search - query = 'How do I use UserManager to create a user?' - - async def run_agent(): - result = await agent.run(query) - return result - - result = asyncio.run(run_agent()) - - # Verify result - self.assertIsNotNone(result) - self.assertIsInstance(result, list) - self.assertTrue(len(result) > 0) - - # Check that assistant message exists - assistant_message = [m for m in result if m.role == 'assistant'] - self.assertTrue(len(assistant_message) > 0) - - # Check that user message has search_result and searching_detail populated - user_messages = [m for m in result if m.role == 'user'] - self.assertTrue(len(user_messages) > 0, "Expected at least one user message") - - # The first user message should have search details after do_rag processing - user_msg = user_messages[0] - self.assertTrue( - hasattr(user_msg, 'searching_detail'), - "User message should have searching_detail attribute" - ) + original = 'How do I use UserManager?' + + async def run(): + messages = [ + Message(role='system', content='You are a helper.'), + Message(role='user', content=original), + ] + await agent.run(messages) + return messages + + messages = asyncio.run(run()) + print(f'messages: {messages}') + + def test_tool_manager_registers_localsearch(self): + """When tools.localsearch.paths is set, ToolManager exposes localsearch.""" + + async def run(): + config = self._base_config() + tm = ToolManager(config, trust_remote_code=False) + await tm.connect() + tools = await tm.get_tools() + await tm.cleanup() + return tools + + tools = asyncio.run(run()) + names = [t['tool_name'] for t in tools] self.assertTrue( - hasattr(user_msg, 'search_result'), - "User message should have search_result attribute" + any(n.endswith('localsearch') for n in names), + f'Expected localsearch in tools, got: {names}', ) - # Check that searching_detail is a dict with expected keys - self.assertIsInstance( - user_msg.searching_detail, dict, - "searching_detail should be a dictionary" - ) - self.assertIn('logs', user_msg.searching_detail) - self.assertIn('mode', user_msg.searching_detail) - self.assertIn('paths', user_msg.searching_detail) - - # Check that search_result is a list (may be empty if no relevant docs found) - self.assertIsInstance( - user_msg.search_result, list, - "search_result should be a list" - ) + @unittest.skipUnless( + os.getenv('TEST_SIRCHMUNK_SMOKE', ''), + 'Set TEST_SIRCHMUNK_SMOKE=1 to run sirchmunk API smoke test', + ) + def test_sirchmunk_search_query_smoke(self): + """Optional: run sirchmunk once (needs network / valid API keys).""" + config = self._base_config() + searcher = SirchmunkSearch(config) + result = asyncio.run(searcher.query('UserManager')) + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) if __name__ == '__main__': diff --git a/tests/utils/test_filesystem_tool_config.py b/tests/utils/test_filesystem_tool_config.py new file mode 100644 index 000000000..5bc8310c9 --- /dev/null +++ b/tests/utils/test_filesystem_tool_config.py @@ -0,0 +1,54 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""FileSystemTool config: include aliases and grep/glob registration.""" + +import asyncio +import tempfile + +from omegaconf import OmegaConf + +from ms_agent.tools.filesystem_tool import FileSystemTool + + +def test_include_short_aliases_expand_to_canonical_names(): + async def _run(): + with tempfile.TemporaryDirectory() as td: + cfg = OmegaConf.create({ + 'output_dir': td, + 'tools': { + 'file_system': { + 'mcp': False, + 'include': ['read', 'write', 'glob'], + }, + }, + }) + fs = FileSystemTool(cfg) + tools = await fs.get_tools() + names = [t['tool_name'] for t in tools['file_system']] + assert 'read_file' in names + assert 'write_file' in names + assert 'glob' in names + assert 'grep' not in names + assert 'read' not in names + assert 'write' not in names + + asyncio.run(_run()) + + +def test_grep_glob_listed_with_full_names(): + async def _run(): + with tempfile.TemporaryDirectory() as td: + cfg = OmegaConf.create({ + 'output_dir': td, + 'tools': { + 'file_system': { + 'mcp': False, + 'include': ['grep', 'glob'], + }, + }, + }) + fs = FileSystemTool(cfg) + tools = await fs.get_tools() + names = [t['tool_name'] for t in tools['file_system']] + assert names == ['grep', 'glob'] + + asyncio.run(_run()) diff --git a/tests/utils/test_snapshot_smoke.py b/tests/utils/test_snapshot_smoke.py new file mode 100644 index 000000000..17bbc2e76 --- /dev/null +++ b/tests/utils/test_snapshot_smoke.py @@ -0,0 +1,331 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Smoke tests for the snapshot utility and LLMAgent rollback interface. + +No network, no LLM — all tests run fully offline using tempfile directories. +""" +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from ms_agent.utils.snapshot import ( + list_snapshots, + restore_snapshot, + take_snapshot, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write(path: str, content: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w', encoding='utf-8') as f: + f.write(content) + + +def _read(path: str) -> str: + with open(path, 'r', encoding='utf-8') as f: + return f.read() + + +# --------------------------------------------------------------------------- +# snapshot utility tests +# --------------------------------------------------------------------------- + +class TestTakeSnapshot(unittest.TestCase): + + def test_empty_dir_returns_none(self): + """Nothing to commit → None.""" + with tempfile.TemporaryDirectory() as td: + result = take_snapshot(td, 'empty') + self.assertIsNone(result) + + def test_new_file_returns_hash(self): + """A new file produces a commit hash.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'hello.txt'), 'hello') + h = take_snapshot(td, 'add hello.txt', message_count=2) + self.assertIsNotNone(h) + self.assertIsInstance(h, str) + self.assertGreater(len(h), 0) + + def test_no_change_after_snapshot_returns_none(self): + """Second snapshot with no changes → None.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'first') + result = take_snapshot(td, 'second — no change') + self.assertIsNone(result) + + def test_message_truncated_to_120_chars(self): + """Long messages are truncated in the commit subject.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'x') + h = take_snapshot(td, 'A' * 200) + self.assertIsNotNone(h) + snaps = list_snapshots(td) + self.assertEqual(len(snaps[0]['message']), 120) + + def test_snapshot_dir_not_tracked(self): + """The .ms_agent_snapshots dir itself must not appear in git status.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'first') + # After snapshot, no pending changes (snapshot dir excluded) + result = take_snapshot(td, 'should be nothing') + self.assertIsNone(result) + + +class TestListSnapshots(unittest.TestCase): + + def test_no_repo_returns_empty(self): + with tempfile.TemporaryDirectory() as td: + self.assertEqual(list_snapshots(td), []) + + def test_returns_snapshots_most_recent_first(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'a.txt'), 'v1') + h1 = take_snapshot(td, 'first snap', message_count=1) + _write(os.path.join(td, 'a.txt'), 'v2') + h2 = take_snapshot(td, 'second snap', message_count=3) + + snaps = list_snapshots(td) + self.assertEqual(len(snaps), 2) + # Most recent first + self.assertEqual(snaps[0]['hash'], h2) + self.assertEqual(snaps[1]['hash'], h1) + + def test_snapshot_fields(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'b.txt'), 'hello') + h = take_snapshot(td, 'my message', message_count=5) + snaps = list_snapshots(td) + self.assertEqual(len(snaps), 1) + s = snaps[0] + self.assertEqual(s['hash'], h) + self.assertEqual(s['message'], 'my message') + self.assertEqual(s['message_count'], 5) + self.assertIn('date', s) + + def test_message_count_default_zero(self): + """message_count defaults to 0 when not passed.""" + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'c.txt'), 'x') + take_snapshot(td, 'no count') + snaps = list_snapshots(td) + self.assertEqual(snaps[0]['message_count'], 0) + + +class TestRestoreSnapshot(unittest.TestCase): + + def test_no_repo_returns_false(self): + with tempfile.TemporaryDirectory() as td: + ok, mc = restore_snapshot(td, 'abc1234') + self.assertFalse(ok) + self.assertEqual(mc, 0) + + def test_restore_reverts_file_content(self): + with tempfile.TemporaryDirectory() as td: + path = os.path.join(td, 'data.txt') + _write(path, 'original') + h1 = take_snapshot(td, 'original state', message_count=2) + + _write(path, 'modified') + take_snapshot(td, 'modified state', message_count=4) + + self.assertEqual(_read(path), 'modified') + + ok, mc = restore_snapshot(td, h1) + self.assertTrue(ok) + self.assertEqual(mc, 2) + self.assertEqual(_read(path), 'original') + + def test_restore_returns_message_count(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'x.txt'), 'a') + h = take_snapshot(td, 'snap', message_count=7) + _write(os.path.join(td, 'x.txt'), 'b') + take_snapshot(td, 'snap2', message_count=9) + + ok, mc = restore_snapshot(td, h) + self.assertTrue(ok) + self.assertEqual(mc, 7) + + def test_restore_deleted_file(self): + """A file deleted after snapshot is recreated on restore.""" + with tempfile.TemporaryDirectory() as td: + path = os.path.join(td, 'will_be_deleted.txt') + _write(path, 'keep me') + h = take_snapshot(td, 'before delete', message_count=1) + + os.remove(path) + take_snapshot(td, 'after delete', message_count=2) + self.assertFalse(os.path.exists(path)) + + ok, _ = restore_snapshot(td, h) + self.assertTrue(ok) + self.assertTrue(os.path.exists(path)) + self.assertEqual(_read(path), 'keep me') + + +# --------------------------------------------------------------------------- +# LLMAgent interface tests +# --------------------------------------------------------------------------- + +class TestLLMAgentSnapshotInterface(unittest.TestCase): + """ + Tests for LLMAgent.list_snapshots() and LLMAgent.rollback(). + The LLM and tool_manager are not initialised — we only exercise the + snapshot-related methods which don't require them. + """ + + def _make_agent(self, output_dir: str): + from omegaconf import OmegaConf + from ms_agent.agent.llm_agent import LLMAgent + cfg = OmegaConf.create({ + 'llm': {'model': 'fake', 'api_key': 'fake', 'model_server': 'openai'}, + 'output_dir': output_dir, + }) + agent = LLMAgent(cfg, tag='smoke-test') + return agent + + def test_list_snapshots_empty(self): + with tempfile.TemporaryDirectory() as td: + agent = self._make_agent(td) + self.assertEqual(agent.list_snapshots(), []) + + def test_list_snapshots_delegates_to_utility(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'snap', message_count=3) + + agent = self._make_agent(td) + snaps = agent.list_snapshots() + self.assertEqual(len(snaps), 1) + self.assertEqual(snaps[0]['message_count'], 3) + + def test_rollback_restores_files_and_truncates_history(self): + from omegaconf import OmegaConf + from ms_agent.agent.llm_agent import LLMAgent + from ms_agent.llm.utils import Message + from ms_agent.utils import save_history + + with tempfile.TemporaryDirectory() as td: + agent = self._make_agent(td) + + # Write a file and take a snapshot with message_count=2 + path = os.path.join(td, 'work.txt') + _write(path, 'v1') + h1 = take_snapshot(td, '[pre] first task', message_count=2) + + # Save 4 messages to history + messages = [ + Message(role='system', content='sys'), + Message(role='user', content='task1'), + Message(role='assistant', content='done1'), + Message(role='user', content='task2'), + ] + save_history(td, 'smoke-test', agent.config, messages) + + # Modify the file and take a second snapshot + _write(path, 'v2') + take_snapshot(td, '[pre] second task', message_count=4) + + self.assertEqual(_read(path), 'v2') + + # Rollback to h1 + ok = agent.rollback(h1) + self.assertTrue(ok) + + # File should be restored + self.assertEqual(_read(path), 'v1') + + # History should be truncated to message_count=2 + from ms_agent.utils import read_history + _, saved = read_history(td, 'smoke-test') + self.assertIsNotNone(saved) + self.assertEqual(len(saved), 2) + + def test_rollback_clears_read_cache(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + h = take_snapshot(td, 'snap', message_count=1) + _write(os.path.join(td, 'f.txt'), 'v2') + take_snapshot(td, 'snap2', message_count=2) + + agent = self._make_agent(td) + + # Attach a fake tool with _read_cache + fake_tool = MagicMock() + fake_tool._read_cache = {'some/path': {'mtime': 123}} + fake_manager = MagicMock() + fake_manager.extra_tools = [fake_tool] + agent.tool_manager = fake_manager + + agent.rollback(h) + self.assertEqual(fake_tool._read_cache, {}) + + def test_rollback_invalid_hash_returns_false(self): + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'f.txt'), 'v1') + take_snapshot(td, 'snap') + + agent = self._make_agent(td) + ok = agent.rollback('deadbeef') + self.assertFalse(ok) + + def test_on_task_begin_auto_snapshots(self): + """on_task_begin should take a snapshot automatically — no explicit call needed.""" + import asyncio + from ms_agent.llm.utils import Message + + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'work.txt'), 'v1') + agent = self._make_agent(td) + + messages = [ + Message(role='system', content='sys'), + Message(role='user', content='do something useful'), + ] + + # No explicit take_snapshot call — on_task_begin should do it + asyncio.run(agent.on_task_begin(messages)) + + snaps = list_snapshots(td) + self.assertEqual(len(snaps), 1) + self.assertIn('do something useful', snaps[0]['message']) + self.assertEqual(snaps[0]['message_count'], len(messages)) + + def test_on_task_begin_no_snapshot_when_disabled(self): + """enable_snapshots=False suppresses automatic snapshot.""" + import asyncio + from omegaconf import OmegaConf + from ms_agent.agent.llm_agent import LLMAgent + from ms_agent.llm.utils import Message + + with tempfile.TemporaryDirectory() as td: + _write(os.path.join(td, 'work.txt'), 'v1') + cfg = OmegaConf.create({ + 'llm': {'model': 'fake', 'api_key': 'fake', 'model_server': 'openai'}, + 'output_dir': td, + 'enable_snapshots': False, + }) + agent = LLMAgent(cfg, tag='smoke-test') + messages = [ + Message(role='system', content='sys'), + Message(role='user', content='task'), + ] + asyncio.run(agent.on_task_begin(messages)) + self.assertEqual(list_snapshots(td), []) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_task_manager_smoke.py b/tests/utils/test_task_manager_smoke.py new file mode 100644 index 000000000..f29f57321 --- /dev/null +++ b/tests/utils/test_task_manager_smoke.py @@ -0,0 +1,363 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Smoke tests for TaskManager and AgentTool dynamic/background mode. +No network, no LLM — all tests run fully offline. +""" +import asyncio +import os +import sys +import unittest + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from ms_agent.utils.task_manager import BackgroundTask, TaskManager + + +# --------------------------------------------------------------------------- +# TaskManager unit tests +# --------------------------------------------------------------------------- + +class TestTaskManager(unittest.IsolatedAsyncioTestCase): + + async def test_register_and_complete(self): + tm = TaskManager() + task_id = tm.register('agent', 'my_tool', 'do something') + self.assertIn(task_id, tm._tasks) + self.assertEqual(tm._tasks[task_id].status, 'running') + + await tm.complete(task_id, 'great result') + self.assertEqual(tm._tasks[task_id].status, 'completed') + self.assertEqual(tm._tasks[task_id].result, 'great result') + + notifications = tm.drain_notifications() + self.assertEqual(len(notifications), 1) + self.assertIn('completed', notifications[0]) + self.assertIn('great result', notifications[0]) + self.assertIn(task_id, notifications[0]) + + async def test_register_and_fail(self): + tm = TaskManager() + task_id = tm.register('agent', 'my_tool', 'do something') + await tm.fail(task_id, 'something went wrong') + self.assertEqual(tm._tasks[task_id].status, 'failed') + + notifications = tm.drain_notifications() + self.assertEqual(len(notifications), 1) + self.assertIn('failed', notifications[0]) + self.assertIn('something went wrong', notifications[0]) + + def test_kill(self): + tm = TaskManager() + task_id = tm.register('agent', 'my_tool', 'do something') + tm.kill(task_id) + self.assertEqual(tm._tasks[task_id].status, 'killed') + # kill again is a no-op + tm.kill(task_id) + self.assertEqual(tm._tasks[task_id].status, 'killed') + + def test_kill_all(self): + tm = TaskManager() + ids = [tm.register('agent', 'tool', f'task {i}') for i in range(3)] + tm.kill_all() + for tid in ids: + self.assertEqual(tm._tasks[tid].status, 'killed') + + def test_drain_empty(self): + tm = TaskManager() + self.assertEqual(tm.drain_notifications(), []) + + async def test_drain_multiple(self): + tm = TaskManager() + id1 = tm.register('agent', 'tool_a', 'task a') + id2 = tm.register('agent', 'tool_b', 'task b') + await tm.complete(id1, 'result a') + await tm.fail(id2, 'error b') + notifications = tm.drain_notifications() + self.assertEqual(len(notifications), 2) + # drain again should be empty + self.assertEqual(tm.drain_notifications(), []) + + def test_get_task(self): + tm = TaskManager() + task_id = tm.register('shell', 'bash', 'run script') + task = tm.get_task(task_id) + self.assertIsNotNone(task) + self.assertEqual(task.task_type, 'shell') + self.assertIsNone(tm.get_task('nonexistent')) + + def test_running_tasks(self): + tm = TaskManager() + id1 = tm.register('agent', 'tool', 'task 1') + id2 = tm.register('agent', 'tool', 'task 2') + tm.kill(id2) + running = tm.running_tasks() + self.assertEqual(len(running), 1) + self.assertEqual(running[0].task_id, id1) + + async def test_notification_xml_structure(self): + tm = TaskManager() + task_id = tm.register('agent', 'searcher_tool', 'search for X') + await tm.complete(task_id, 'found Y') + notif = tm.drain_notifications()[0] + self.assertTrue(notif.startswith('')) + self.assertTrue(notif.strip().endswith('')) + self.assertIn('', notif) + self.assertIn('agent', notif) + self.assertIn('searcher_tool', notif) + self.assertIn('search for X', notif) + self.assertIn('completed', notif) + self.assertIn('found Y', notif) + self.assertIn('', notif) + + +# --------------------------------------------------------------------------- +# AgentTool dynamic spec (merged SplitTask) — schema validation only +# --------------------------------------------------------------------------- + +class TestAgentToolDynamicSpec(unittest.TestCase): + + def _make_config(self): + from omegaconf import OmegaConf + return OmegaConf.create({ + 'tag': 'test-agent', + 'output_dir': '/tmp/test_agent_tool', + 'tools': { + 'split_task': { + 'tag_prefix': 'worker-', + 'run_in_thread': False, + 'run_in_process': False, + } + } + }) + + def test_split_task_spec_registered(self): + from ms_agent.tools.agent_tool import AgentTool + config = self._make_config() + tool = AgentTool(config) + self.assertTrue(tool.enabled) + self.assertIn('split_to_sub_task', tool._specs) + + def test_split_task_spec_is_dynamic(self): + from ms_agent.tools.agent_tool import AgentTool + config = self._make_config() + tool = AgentTool(config) + spec = tool._specs['split_to_sub_task'] + self.assertTrue(spec.dynamic) + self.assertFalse(spec.run_in_process) + + def test_split_task_parameters_schema(self): + from ms_agent.tools.agent_tool import AgentTool + config = self._make_config() + tool = AgentTool(config) + spec = tool._specs['split_to_sub_task'] + props = spec.parameters['properties'] + self.assertIn('tasks', props) + self.assertIn('execution_mode', props) + # execution_mode must have enum + self.assertIn('enum', props['execution_mode']) + self.assertIn('parallel', props['execution_mode']['enum']) + self.assertIn('sequential', props['execution_mode']['enum']) + + def test_dynamic_mode_in_agent_tools_definitions(self): + from ms_agent.tools.agent_tool import AgentTool + from omegaconf import OmegaConf + config = OmegaConf.create({ + 'tag': 'test-agent', + 'output_dir': '/tmp/test_agent_tool', + 'tools': { + 'agent_tools': { + 'definitions': [{ + 'tool_name': 'my_dynamic_tool', + 'mode': 'dynamic', + 'description': 'A dynamic tool', + }] + } + } + }) + tool = AgentTool(config) + self.assertIn('my_dynamic_tool', tool._specs) + self.assertTrue(tool._specs['my_dynamic_tool'].dynamic) + + +# --------------------------------------------------------------------------- +# TaskControlTool unit tests +# --------------------------------------------------------------------------- + +class TestTaskControlTool(unittest.IsolatedAsyncioTestCase): + + def _make_tool(self): + from ms_agent.tools.task_control_tool import TaskControlTool + from omegaconf import OmegaConf + config = OmegaConf.create({'output_dir': '/tmp'}) + tool = TaskControlTool(config) + tm = TaskManager() + tool.set_task_manager(tm) + return tool, tm + + async def test_list_tasks_empty(self): + tool, _ = self._make_tool() + result = await tool.call_tool('task_control', tool_name='list_tasks', tool_args={}) + self.assertEqual(result, 'No background tasks registered.') + + async def test_list_tasks_with_entries(self): + import json + tool, tm = self._make_tool() + tm.register('agent', 'searcher', 'search X') + result = await tool.call_tool('task_control', tool_name='list_tasks', tool_args={}) + rows = json.loads(result) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]['tool_name'], 'searcher') + self.assertEqual(rows[0]['status'], 'running') + + async def test_cancel_task(self): + tool, tm = self._make_tool() + task_id = tm.register('agent', 'searcher', 'search X') + result = await tool.call_tool('task_control', tool_name='cancel_task', + tool_args={'task_id': task_id}) + self.assertIn('cancelled', result) + self.assertEqual(tm.get_task(task_id).status, 'killed') + + async def test_cancel_nonexistent(self): + tool, _ = self._make_tool() + result = await tool.call_tool('task_control', tool_name='cancel_task', + tool_args={'task_id': 'bad-id'}) + self.assertIn('not found', result) + + async def test_cancel_already_done(self): + tool, tm = self._make_tool() + task_id = tm.register('agent', 'searcher', 'search X') + tm.kill(task_id) + result = await tool.call_tool('task_control', tool_name='cancel_task', + tool_args={'task_id': task_id}) + self.assertIn('already', result) + + +# --------------------------------------------------------------------------- +# AgentTool escape-to-background (sync_timeout_s + escape_to_background API) +# --------------------------------------------------------------------------- + +class TestAgentToolEscape(unittest.IsolatedAsyncioTestCase): + """Tests for _run_sync_escapable and escape_to_background. + + We mock _run_agent and _launch_background so no real sub-agent is needed. + """ + + def _make_agent_tool(self): + from ms_agent.tools.agent_tool import AgentTool, _AgentToolSpec + from ms_agent.utils.task_manager import TaskManager + from omegaconf import OmegaConf + config = OmegaConf.create({ + 'tag': 'test', + 'output_dir': '/tmp', + 'tools': {}, + }) + tool = AgentTool(config) + tm = TaskManager() + tool.set_task_manager(tm) + return tool, tm + + def _make_spec(self, sync_timeout_s=None): + from ms_agent.tools.agent_tool import _AgentToolSpec + return _AgentToolSpec( + tool_name='test_tool', + description='test', + parameters={}, + config_path=None, + inline_config=None, + server_name='test_server', + tag_prefix='t-', + input_mode='text', + request_field='request', + input_template=None, + output_mode='final_message', + max_output_chars=1000, + trust_remote_code=None, + env=None, + run_in_thread=False, + run_in_process=False, + dynamic=False, + sync_timeout_s=sync_timeout_s, + ) + + async def test_normal_completion(self): + """Without timeout, task completes normally.""" + tool, _ = self._make_agent_tool() + spec = self._make_spec() + + async def fake_run_agent(agent, payload, spec, call_id=None): + return ['result'] + + tool._run_agent = fake_run_agent + result = await tool._run_sync_escapable('payload', spec, 'cid1') + self.assertEqual(result, ['result']) + + async def test_escape_to_background_via_api(self): + """escape_to_background() triggers escape before task completes.""" + tool, tm = self._make_agent_tool() + spec = self._make_spec() + + started = asyncio.Event() + + async def slow_run_agent(agent, payload, spec, call_id=None): + started.set() + await asyncio.sleep(10) # will be cancelled + return ['should not reach'] + + launched = {} + + async def fake_launch_background(payload, spec, call_id): + import json + launched['called'] = True + task_id = tm.register('agent', spec.tool_name, 'escaped') + return json.dumps({'status': 'async_launched', 'task_id': task_id}) + + tool._run_agent = slow_run_agent + tool._launch_background = fake_launch_background + + async def _trigger_escape(): + await started.wait() + tool.escape_to_background('cid2') + + result_task = asyncio.create_task( + tool._run_sync_escapable('payload', spec, 'cid2')) + await asyncio.gather(_trigger_escape(), return_exceptions=True) + result = await result_task + + self.assertIsInstance(result, str) + import json + data = json.loads(result) + self.assertEqual(data['status'], 'async_launched') + self.assertTrue(launched.get('called')) + + async def test_escape_to_background_unknown_call_id(self): + """escape_to_background returns False for unknown call_id.""" + tool, _ = self._make_agent_tool() + self.assertFalse(tool.escape_to_background('nonexistent')) + + async def test_sync_timeout_triggers_escape(self): + """sync_timeout_s causes auto-escape when task takes too long.""" + tool, tm = self._make_agent_tool() + spec = self._make_spec(sync_timeout_s=0.05) + + async def slow_run_agent(agent, payload, spec, call_id=None): + await asyncio.sleep(10) + return ['should not reach'] + + async def fake_launch_background(payload, spec, call_id): + import json + task_id = tm.register('agent', spec.tool_name, 'timed out') + return json.dumps({'status': 'async_launched', 'task_id': task_id}) + + tool._run_agent = slow_run_agent + tool._launch_background = fake_launch_background + + result = await tool._run_sync_escapable('payload', spec, 'cid3') + import json + data = json.loads(result) + self.assertEqual(data['status'], 'async_launched') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_workspace_policy.py b/tests/utils/test_workspace_policy.py new file mode 100644 index 000000000..012888533 --- /dev/null +++ b/tests/utils/test_workspace_policy.py @@ -0,0 +1,78 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for WorkspacePolicyKernel.""" + +import tempfile +from pathlib import Path + +import pytest + +from ms_agent.utils.workspace_policy import WorkspacePolicyError, WorkspacePolicyKernel + + +def test_default_root_is_output_dir(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel(out) + p = k.resolve_under_roots('foo/bar') + assert p == (out / 'foo' / 'bar').resolve() + + +def test_rejects_escape(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel(out) + with pytest.raises(WorkspacePolicyError): + k.resolve_under_roots('../../etc/passwd') + + +def test_extra_allow_root(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + other = Path(td) / 'other' + out.mkdir() + other.mkdir() + k = WorkspacePolicyKernel(out, extra_allow_roots=[str(other)]) + assert k.resolve_under_roots(str(other / 'x')) == (other / 'x').resolve() + + +def test_read_only_blocks_redirect(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel( + out, + shell_default_mode='read_only', + ) + with pytest.raises(WorkspacePolicyError): + k.assert_shell_command_allowed('echo x > file.txt') + + +def test_workspace_write_allows_redirect_but_blocks_network(): + with tempfile.TemporaryDirectory() as td: + out = Path(td) / 'out' + out.mkdir() + k = WorkspacePolicyKernel( + out, + shell_default_mode='workspace_write', + shell_network_enabled=False, + ) + k.assert_shell_command_allowed('echo x > file.txt') + with pytest.raises(WorkspacePolicyError): + k.assert_shell_command_allowed('curl https://example.com') + + +def test_artifact_manager_spill(tmp_path): + from ms_agent.utils.artifact_manager import ArtifactManager + + am = ArtifactManager(tmp_path, max_combined_bytes=32) + big = 'a' * 100 + packed = am.pack_text_result( + tool_name='t', + call_id='c1', + stdout=big, + stderr='', + ) + assert packed.get('truncated') is True + assert 'artifact_path' in packed diff --git a/webui/backend/agent_runner.py b/webui/backend/agent_runner.py index 1c122ff27..3f3879ac1 100644 --- a/webui/backend/agent_runner.py +++ b/webui/backend/agent_runner.py @@ -9,11 +9,10 @@ import signal import subprocess import sys +import yaml from datetime import datetime from typing import Any, Callable, Dict, Optional -import yaml - class AgentRunner: """Runs ms-agent as a subprocess with output streaming""" diff --git a/webui/backend/api.py b/webui/backend/api.py index 126ac4c02..fa68b849b 100644 --- a/webui/backend/api.py +++ b/webui/backend/api.py @@ -4,14 +4,13 @@ """ import mimetypes import os -from pathlib import Path -from typing import Any, Dict, List, Optional - from fastapi import APIRouter, HTTPException, Query from fastapi.responses import FileResponse +from pathlib import Path from pydantic import BaseModel, Field # Import shared instances from shared import config_manager, project_discovery, session_manager +from typing import Any, Dict, List, Optional router = APIRouter() diff --git a/webui/backend/config_manager.py b/webui/backend/config_manager.py index eddf94915..8d9c7f623 100644 --- a/webui/backend/config_manager.py +++ b/webui/backend/config_manager.py @@ -3,12 +3,11 @@ Configuration management for MS-Agent Web UI Handles global settings, LLM configuration, and MCP server configuration. """ +import json import os from threading import Lock from typing import Any, Dict, Optional -import json - class ConfigManager: """Manages global configuration for the Web UI""" diff --git a/webui/backend/deep_research_eventizer.py b/webui/backend/deep_research_eventizer.py index 84949674a..f1d5d6d74 100644 --- a/webui/backend/deep_research_eventizer.py +++ b/webui/backend/deep_research_eventizer.py @@ -1,6 +1,6 @@ +import json from typing import Any, Callable, Dict, List, Optional -import json from ms_agent.llm.utils import Message, ToolCall diff --git a/webui/backend/deep_research_worker.py b/webui/backend/deep_research_worker.py index e8a3480a3..5a0ce2d80 100644 --- a/webui/backend/deep_research_worker.py +++ b/webui/backend/deep_research_worker.py @@ -1,17 +1,17 @@ import argparse import asyncio +import json import os import signal import sys import traceback +from deep_research_eventizer import HistoryEventizer # noqa: E402 +from omegaconf import OmegaConf from pathlib import Path from typing import Any, Dict, Optional -import json -from deep_research_eventizer import HistoryEventizer # noqa: E402 from ms_agent.agent.loader import AgentLoader from ms_agent.tools.agent_tool import AgentTool -from omegaconf import OmegaConf BACKEND_DIR = Path(__file__).resolve().parent if str(BACKEND_DIR) not in sys.path: diff --git a/webui/backend/deep_research_worker_manager.py b/webui/backend/deep_research_worker_manager.py index a5eb29d05..30b666207 100644 --- a/webui/backend/deep_research_worker_manager.py +++ b/webui/backend/deep_research_worker_manager.py @@ -1,4 +1,5 @@ import asyncio +import json import os import signal import sys @@ -6,8 +7,6 @@ from pathlib import Path from typing import Any, Awaitable, Callable, Dict, Optional -import json - class DeepResearchWorkerManager: diff --git a/webui/backend/main.py b/webui/backend/main.py index 2afcebbe3..8eb584bc1 100644 --- a/webui/backend/main.py +++ b/webui/backend/main.py @@ -5,7 +5,6 @@ """ import os import sys - import uvicorn from api import router as api_router from fastapi import FastAPI diff --git a/webui/backend/shared.py b/webui/backend/shared.py index 1d3ce9754..fb7b7fe67 100644 --- a/webui/backend/shared.py +++ b/webui/backend/shared.py @@ -4,7 +4,6 @@ Ensures api.py and websocket_handler.py use the same manager instances. """ import os - from config_manager import ConfigManager from project_discovery import ProjectDiscovery from session_manager import SessionManager diff --git a/webui/backend/websocket_handler.py b/webui/backend/websocket_handler.py index 7b2830b17..549d1f41c 100644 --- a/webui/backend/websocket_handler.py +++ b/webui/backend/websocket_handler.py @@ -4,17 +4,16 @@ Handles agent execution, log streaming, and progress updates. """ import asyncio -import os -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Set - import json +import os from agent_runner import AgentRunner +from datetime import datetime from deep_research_worker_manager import DeepResearchWorkerManager from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from pathlib import Path # Import shared instances from shared import config_manager, project_discovery, session_manager +from typing import Any, Dict, Set router = APIRouter() @@ -188,6 +187,7 @@ async def start_agent(session_id: str, data: Dict[str, Any], if session_type == 'chat': # Create a virtual project for chat mode using the default agent.yaml import ms_agent + # Get ms_agent package installation path # Use __path__ which is always available for packages and gives real filesystem paths if hasattr(ms_agent, '__path__') and ms_agent.__path__: